
feat: add Gemma4 dense and MoE custom implementations#2190

Open
hallerite wants to merge 7 commits into main from feat/gemma4

Conversation

@hallerite
Member

@hallerite hallerite commented Apr 3, 2026

Summary

  • Bump transformers to 5.5.0 and vLLM to 0.19.0 for Gemma4 support
  • Custom trainer implementations for Gemma4-31B (dense) and Gemma4-26B-A4B (MoE)
  • Unified registration: one Gemma4ForCausalLM handles both via enable_moe_block config flag
  • Adapt server.py for vLLM 0.19 API changes
  • Remove obsolete vLLM monkey-patches fixed upstream in 0.18/0.19

Gemma4 architecture features

  • Hybrid sliding window + global attention with dual RoPE (theta=10K sliding, theta=1M + partial_rotary_factor=0.25 global)
  • K=V sharing on global attention layers (no v_proj)
  • QKV norms (q/k with learnable scale, v without)
  • Attention scaling = 1.0
  • Per-layer learnable scalar and scaled embeddings
  • Logit softcapping (tanh at 30.0)
  • MoE: shared expert + sparse experts in parallel, custom router with norm+scale+per_expert_scale
  • MoE weight conversion: HF fused gate_up_proj -> PrimeRL w1/w2/w3
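The logit softcapping from the list above is a simple tanh squash; a minimal scalar sketch (the 30.0 cap is the value stated above, the function name is illustrative):

```python
import math

def softcap(logit: float, cap: float = 30.0) -> float:
    """Squash a logit into [-cap, cap] via tanh (Gemma-style softcapping)."""
    return cap * math.tanh(logit / cap)
```

Near zero the function is approximately identity; large logits saturate at the cap, which keeps the output distribution well-behaved without hard clipping.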

vLLM 0.19 changes

  • Removed run_api_server_worker_proc import + monkey-patch (deleted upstream)
  • Updated build_app signature (new model_config param)
  • Simplified custom_init_app_state -- no longer duplicates constructor kwargs, just swaps the class on the already-constructed serving_chat object
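The swap-in approach from the last bullet can be sketched with stand-in classes (names are illustrative, not vLLM's actual types):

```python
class ServingChat:
    """Stand-in for the serving object vLLM constructs."""
    def __init__(self, model_name: str):
        self.model_name = model_name

class ServingChatWithTokens(ServingChat):
    """Subclass adding behavior, without re-running __init__."""
    def extra(self) -> str:
        return f"tokens for {self.model_name}"

# Swap the class on the already-constructed object instead of duplicating
# the constructor kwargs: allocate without __init__, then copy state over.
original = ServingChat("gemma4")
swapped = object.__new__(ServingChatWithTokens)
swapped.__dict__.update(original.__dict__)
```

This avoids keeping a copy of vLLM's constructor signature in sync across upgrades, at the cost of depending on the subclass not requiring extra initialization.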

Removed monkey-patches (fixed upstream)

  • monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode -- fixed in vLLM 0.18 (PR #30220)
  • monkey_patch_tokenizer_thread_safety -- fixed in vLLM 0.19 (PR #34789, single-threaded tokenizer executor)
  • monkey_patch_hermes_tool_parser_thread_safety -- fixed in vLLM 0.19 (rewritten, no more tokenizer calls in init)

Patches to revisit

  • monkey_patch_load_lora_adapter / monkey_patch_LRUCacheWorkerLoRAManager -- may be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326)
  • monkey_patch_dp_engine_core_pause_resume_deadlock -- partially fixed in vLLM 0.19 (PR #37024) but both critical bugs still present
  • monkey_patch_harmony_stop_token_propagation -- upstream issue still open ([Bug]: [gpt oss 20b] [tool_call] Unexpected token 12606 while expecting start token 200006 vllm-project/vllm#22519)
  • monkey_patch_tokenize_params_validation -- still overly restrictive upstream
  • monkey_patch_fused_moe_lora_dp -- DP-specific LoRA+MoE bugs still present
  • monkey_patch_minimax_m2_for_lora -- gate dtype + key remapping still needed

Compat fixes

Verified

  • 365 unit tests pass (5 new Gemma4: forward, backward, no-KV-sharing, MoE forward, MoE backward)
  • SFT trains on both dense and MoE (5 steps each)
  • RL e2e completes 2 steps on reverse-text -- dense mismatch KL ~0.001, MoE ~0.008

Note

High Risk
High risk due to a large new custom model implementation (Gemma4 dense+MoE, including VLM handling and weight conversion) plus dependency upgrades to transformers/vllm and corresponding inference server behavior changes.

Overview
Adds first-class Gemma4 support by upgrading dependencies (transformers 5.5-based rev, vllm >=0.19) and introducing a new custom Gemma4 trainer implementation that covers both dense and MoE variants via a unified Gemma4ForCausalLM (plus MoE weight-conversion helpers and VLM text/vision composition).

Updates inference integration for vLLM 0.19 by simplifying init_app_state customization (swap-in OpenAIServingChatWithTokens instead of re-constructing it), adjusting build_app signature usage, extending tool parser auto-resolution to Gemma4, and removing now-obsolete monkey patches; also changes error handling to let FastAPI’s global GenerationError handler run.

Improves training ergonomics/compatibility: adds a ring_flash_attn + transformers compatibility shim loaded early, adds chat_template_kwargs plumbing for SFT datasets/configs (including new Gemma4/Qwen test configs), tightens/guards perf accounting for optional MoE fields, updates RoPE config validation calls, and adds GPU unit tests covering Gemma4 forward/backward for dense and MoE.

Written by Cursor Bugbot for commit 84f98f4. This will update automatically on new commits.

@hallerite hallerite force-pushed the feat/gemma4 branch 4 times, most recently from 2aeebec to 150a47a on April 3, 2026 14:07
Bump transformers to 5.5.0 and vLLM to 0.19.0 for Gemma4 support.

Custom trainer implementations for both Gemma4-31B (dense) and
Gemma4-26B-A4B (MoE) with:
- Hybrid sliding window + global attention with dual RoPE
- K=V sharing on global attention layers
- QKV norms (q/k with scale, v without)
- Attention scaling = 1.0
- Per-layer learnable scalar and scaled embeddings
- Logit softcapping (tanh at 30.0)
- MoE: shared expert + sparse experts in parallel, custom router
- MoE weight conversion (HF fused gate_up_proj → PrimeRL w1/w2/w3)

Unified registration: one Gemma4ForCausalLM handles both dense and MoE
via the enable_moe_block config flag.

Also includes:
- ring_flash_attn compat shim for transformers 5.5 (removed symbol)
- vLLM 0.19 server.py API adaptation (removed run_api_server_worker_proc)
- perf.py None-guard for optional MoE config fields
- Test fix for Qwen3.5 VLM strict dataclass validation

Verified: 365 unit tests pass, SFT trains on both dense and MoE,
RL e2e completes 2 steps on reverse-text with mismatch KL ~0.001.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
@hallerite hallerite marked this pull request as ready for review April 3, 2026 14:53
serving_chat = object.__new__(OpenAIServingChatWithTokens)
serving_chat.__dict__.update(original_chat.__dict__)
state.openai_serving_chat = serving_chat
state.openai_serving_chat_with_tokens = serving_chat

Missing state attribute causes AttributeError on fallback path

Medium Severity

The new custom_init_app_state only sets state.openai_serving_chat_with_tokens inside the if "generate" in supported_tasks and state.openai_serving_chat is not None block. When this condition is false, the attribute is never set. The old code always set it (to None in the else case). The chat_with_tokens dependency at line 179 directly accesses request.app.state.openai_serving_chat_with_tokens, which will raise AttributeError on Starlette's State object if the attribute was never assigned.
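A minimal sketch of the fix — always assign the attribute so the dependency never hits an unset attribute on the state object (the class below is a stand-in for Starlette's State, which raises AttributeError for attributes that were never set):

```python
class State:
    """Stand-in for starlette.datastructures.State: plain attribute storage,
    AttributeError on anything never assigned."""
    pass

def init_app_state(state: State, generate_supported: bool) -> None:
    # Assign on every path, not only inside the "generate" branch;
    # the fallback path sets None instead of leaving the slot unset.
    state.openai_serving_chat_with_tokens = (
        object() if generate_supported else None
    )
```

With this shape, the `chat_with_tokens` dependency can safely read the attribute and branch on `None` rather than catching AttributeError.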


# Split gate_up into w1 (gate) and w3 (up)
state_dict[f"model.layers.{i}.experts.w1"] = gate_up_proj[:, :moe_dim, :]
state_dict[f"model.layers.{i}.experts.w3"] = gate_up_proj[:, moe_dim:, :]
state_dict[f"model.layers.{i}.experts.w2"] = down_proj

HF-to-PrimeRL conversion silently skips per-expert format

Medium Severity

convert_hf_layer_to_prime only handles the fused experts.gate_up_proj format, but convert_prime_layer_to_hf produces per-expert format (experts.{j}.gate_proj.weight etc.). The is_hf_state_dict check correctly recognizes both formats, but the conversion function silently returns without converting when it encounters per-expert keys. A round-trip (PrimeRL → HF → PrimeRL) would leave MoE weights unconverted.
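One way to close the round-trip gap is a branch that stacks the per-expert HF weights back into PrimeRL's stacked tensors; a sketch with hypothetical key names mirroring the excerpt above (not the repo's actual helper):

```python
import torch

def convert_per_expert_to_prime(state_dict: dict, layer: int, num_experts: int) -> dict:
    """Stack per-expert HF weights (experts.{j}.gate_proj.weight etc.) into
    PrimeRL-style stacked w1/w2/w3 tensors of shape [E, out, in]."""
    prefix = f"model.layers.{layer}.experts"
    w1 = torch.stack([state_dict[f"{prefix}.{j}.gate_proj.weight"] for j in range(num_experts)])
    w3 = torch.stack([state_dict[f"{prefix}.{j}.up_proj.weight"] for j in range(num_experts)])
    w2 = torch.stack([state_dict[f"{prefix}.{j}.down_proj.weight"] for j in range(num_experts)])
    return {f"{prefix}.w1": w1, f"{prefix}.w2": w2, f"{prefix}.w3": w3}
```

The conversion function would dispatch on which key family is present instead of silently returning when the fused `gate_up_proj` keys are absent.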



inv_freq, attn_scaling = rope_init_fn(self.config, **kwargs)
getattr(rotary_emb, f"{layer_type}_inv_freq").copy_(inv_freq)
rotary_emb.attention_scaling[layer_type] = attn_scaling

Duplicate MoE model class when unified registration exists

Low Severity

Gemma4MoeModel and Gemma4MoeForCausalLM duplicate nearly all backbone logic from Gemma4Model and Gemma4ForCausalLM, which already handle MoE via the enable_moe_block config flag. The MoE variants (Gemma4MoeModel.forward, Gemma4MoeForCausalLM.forward, init_buffers_post_meta) are near-identical copies. The standalone Gemma4MoeForCausalLM isn't registered in the model registry, making it unused dead code that increases maintenance burden.


hallerite and others added 2 commits April 3, 2026 23:53
…e kwargs

- Register Gemma4 in VLM_REGISTRY and _CUSTOM_VLM_MAPPING for composite config handling
- Add Gemma4VLMModel wrapper (HF vision tower + custom text model)
- Make Gemma4ForCausalLM handle both text-only and VLM configs (like Qwen3.5)
- Add VLM key remapping helpers for weight conversion
- SDPA fallback in Gemma4Attention for global layers with head_dim=512
  (FlashAttention caps at 256, Gemma4 global_head_dim=512)
- Add chat_template_kwargs to SFTDataConfig for models with thinking-aware
  templates (Gemma4-it requires enable_thinking=true for prefix stability)
- Fix missing state.openai_serving_chat_with_tokens on fallback path in server.py

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
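The SDPA fallback described in the commit above can be sketched as a head-dim-based backend choice (function and constant names are illustrative; 256 is FlashAttention's stated head-dim cap):

```python
def select_attn_backend(head_dim: int, flash_max_head_dim: int = 256) -> str:
    """Fall back to SDPA for head dims FlashAttention can't handle
    (e.g. Gemma4's global_head_dim=512); use FlashAttention otherwise."""
    return "sdpa" if head_dim > flash_max_head_dim else "flash_attention_2"
```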

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 4 total unresolved issues (including 3 from previous reviews).

Autofix Details

Bugbot Autofix prepared a fix for the issue found in the latest run.

  • ✅ Fixed: Missing CHANGELOG entry for new config field
    • Added a CHANGELOG.md entry documenting the new sft.data.chat_template_kwargs SFTDataConfig field and its usage.

Preview (48d213a6e9)
diff --git a/CHANGELOG.md b/CHANGELOG.md
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,7 @@
 
 Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes).
 
+- **`sft.data.chat_template_kwargs`**: Added optional passthrough kwargs for `tokenizer.apply_chat_template()` in `SFTDataConfig` (default: `None`). Use this to pass template-specific options such as `{"enable_thinking": true}` for thinking-aware chat templates. (2026-04-03)
 - **`log.file` and `log.env_worker_logs` removed**: Removed `log.file` (from `LogConfig` and `SharedLogConfig`) and `log.env_worker_logs` (from `LogConfig`). Python file logging is replaced by deployment-level capture. Existing configs using these fields must delete them. Log paths unified: `.stdout` files renamed to `.log`, SLURM logs moved from `slurm/` to `logs/`. (2026-03-31)
 - **`trainer.log.ranks_filter` (NEW)**: Added `ranks_filter: list[int]` to `TrainerLogConfig` (default: `[0]`). Controls which ranks appear in trainer console output via torchrun's `--local-ranks-filter`. (2026-03-31)
 - **`wandb.log_extras.sample_ratio` / monitor sample logging defaults**: `wandb.log_extras.sample_ratio` is now actually applied to W&B sample-table logging via the shared monitor sampler (it was previously a no-op for WandB). Separately, the orchestrator no longer hard-caps sample logging to 8 rollouts before monitor-level sampling runs, so when monitor `sample_ratio` is `None`, monitors now receive and may log the full rollout batch for a step instead of at most 8 rollouts. This affects both W&B and Prime monitor sample logging behavior. (2026-03-27)


description="Extra keyword arguments passed to tokenizer.apply_chat_template(). "
"E.g. {'enable_thinking': true} for models with thinking-aware templates (Gemma4, Qwen3)."
),
] = None

Missing CHANGELOG entry for new config field

Low Severity

A new chat_template_kwargs field is added to SFTDataConfig in src/prime_rl/configs/sft.py, but CHANGELOG.md has no corresponding entry. Per the project rule, any PR modifying configuration structures or usage patterns in config files must update the changelog.
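For reference, the field would be set from an SFT config roughly like this (hypothetical TOML fragment; the exact section path may differ):

```toml
# Pass template-specific options through to tokenizer.apply_chat_template()
[data]
chat_template_kwargs = { enable_thinking = true }
```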


Triggered by project rule: BugBot Instructions

hallerite and others added 3 commits April 3, 2026 18:46
Register google/gemma-4-{26B-A4B,31B}-it with the "gemma4" tool call
parser (vLLM 0.19 native). Reasoning parser ("gemma4") is already
supported via inference.reasoning_parser config field.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Replace custom torch-native Gemma4RMSNorm with the shared
quack-accelerated RMSNorm. The KL mismatch difference was negligible
(0.0012 vs 0.0010) and consistency with other models is more valuable.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.

There are 6 total unresolved issues (including 4 from previous reviews).


out = torch.cat(out_splits, dim=0)
if num_padding > 0:
out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
return out

Float indices from histc break expert for-loop fallback

Medium Severity

_run_experts_for_loop receives num_tokens_per_expert from torch.histc, which returns a float tensor. Calling .tolist() yields Python floats, and sum(num_tokens_per_expert_list) returns a float. Using this as a slice index in x[:sum(...)] raises TypeError because Python slice indices must be integers. The torch.split call with float section sizes also fails. This fallback path (when use_grouped_mm=False) is broken.
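A minimal reproduction of the fix: `torch.histc` returns a float tensor, so cast the counts to integers before using them as slice bounds or `torch.split` section sizes.

```python
import torch

# histc bins: [-0.5, 0.5), [0.5, 1.5), [1.5, 2.5] -> counts [2, 1, 3]
x = torch.tensor([0.0, 0.0, 1.0, 2.0, 2.0, 2.0])
counts = torch.histc(x, bins=3, min=-0.5, max=2.5)   # float tensor
sizes = counts.long().tolist()                        # Python ints, safe for
splits = torch.split(x, sizes)                        # slicing and torch.split
```

Without the `.long()` cast, `counts.tolist()` yields Python floats, and both `x[:sum(sizes)]` and `torch.split(x, sizes)` raise TypeError.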


# Flash attention expects [total_tokens, heads, dim]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

Redundant double transpose is a no-op in attention

Low Severity

In Gemma4Attention.forward, query/key/value states are transposed from [B, T, H, D] to [B, H, T, D] on lines 295/304/307, then immediately transposed back to [B, T, H, D] on lines 310-312 before being indexed with [0]. The two transposes cancel each other out entirely, making the first set unnecessary.
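The cancellation is easy to verify: transposing the same pair of dimensions twice is the identity.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)           # pretend [B, T, H*D]-ish layout
roundtrip = x.transpose(1, 2).transpose(1, 2)   # [B,T,H,D] -> [B,H,T,D] -> back
assert torch.equal(roundtrip, x)                # no-op: both transposes cancel
```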

