Skip to content

Add Gemma 4 31B-IT model, export, and quantization framework for ExecuTorch#19213

Open
mergennachin wants to merge 11 commits intomainfrom
gemma4-31b-quant-framework
Open

Add Gemma 4 31B-IT model, export, and quantization framework for ExecuTorch#19213
mergennachin wants to merge 11 commits intomainfrom
gemma4-31b-quant-framework

Conversation

@mergennachin
Copy link
Copy Markdown
Contributor

@mergennachin mergennachin commented Apr 29, 2026

Text-only export of Gemma 4 31B-IT to ExecuTorch with INT4/INT8 weight quantization. Quantized weights use torchao's native tensor subclasses (Int4Tensor, IntxUnpackedToInt8Tensor) for serialization, aligning with the torchao ecosystem.

quant/ package separates quantization into independent modules:

  • recipe.py: declarative QuantRecipe with regex FQN matching and per-layer overrides
  • quantize.py: quantize_weight / dequantize_weight / quantize_model — returns torchao subclasses directly. 8-bit fully delegates to IntxUnpackedToInt8Tensor.from_hp (min_max and HQQ). 4-bit uses torchao primitives + manual Int4Tensor construction (pending mslk availability for from_hp)
  • pack.py: pack_model (bulk, groups by parent for MoE) and pack_one (streaming). Dispatches via isinstance(_, TorchAOBaseTensor)
  • pack_cuda.py: converts Int4Tensor to IntxUnpackedToInt8Tensor (int4 values unpacked to int8) and passes INT8 IntxUnpackedToInt8Tensor through unchanged. No CUDA required for packing — the CUDA-specific tinygemm conversion is a source transform applied at export time
  • gguf.py: unpack Q4_K/Q6_K GGUF blocks directly to Int4Tensor/IntxUnpackedToInt8Tensor, with streaming iterator

Serialization uses torchao's safetensors integration (torchao.prototype.safetensors) — no custom format. Checkpoints are compatible with torchao's save_pretrained/load_pretrained and can be loaded by vLLM.

This framework is designed to be promoted and reused for Qwen 3.5 MoE and other models — adding a new model requires only a QuantRecipe and optionally a custom packer.

Quantization recipes: "default" (INT4 min_max linears + INT8 per-axis embedding) and "sensitive" (INT8 for edge-layer v_proj/down_proj, INT4 HQQ asymmetric elsewhere).

Dual-path INT4 linear dispatch: IntxUnpackedToInt8Tensor's F.linear dispatch dequantizes to bf16 and calls cuBLAS, optimal for prefill (12x faster than tinygemm at T=2048). For decode, a model-agnostic source transform (backends/cuda/transforms/int4_linear_dispatch.py) converts to Int4TilePackedTo4dTensor (tinygemm), optimal for M=1. Export flow: prefill first (dequant+cuBLAS), then tinygemm transform, then decode export. inference.py applies the tinygemm transform for fast eager decode.

Split-K flash-decoding: ReplaceEdgeOpWithTritonOpPass in the CUDA backend selects triton::sdpa_decode_splitk for SDPA nodes where L_q=1 and L_kv exceeds 2048. At 128K context, full-attention decode SDPA improves from 15.7ms/layer to 0.7ms/layer (22x). Sliding-window layers (ring buffer <= 2048) use standard triton::sdpa. No model code changes — the pass inspects Q/K shapes in the exported graph automatically.

GGUF support: inference.py --gguf and export.py --gguf load community-quantized GGUF files directly. Tied embed/lm_head is untied — embedding dequantized to bf16 for gather, lm_head keeps INT4 for matmul.

Ring-buffer KV cache: Sliding window layers use RingKVCache (2x window) instead of flat max_seq_len buffers. The C++ runner chunks long prompts automatically via get_max_prefill_chunk metadata. Chunked prefill produces identical logits to sequential (verified by test).

Includes: C++ runner with BOS/EOS handling, chunked prefill, and #ifdef guards for non-CUDA builds; eager inference with torch.compile; unit and integration tests across quant/tests/, tests/, and backends/cuda/tests/.

Copilot AI review requested due to automatic review settings April 29, 2026 21:06
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented Apr 29, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19213

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 5 Unrelated Failures

As of commit 5b54f50 with merge base 8a397b4 (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 29, 2026
@github-actions
Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new Gemma 4 31B-IT example pipeline for ExecuTorch (CUDA backend), including a packing-agnostic quantization format + recipes, CUDA packers, export/inference scripts, a C++ runner, and CI coverage.

Changes:

  • Introduces examples/models/gemma4_31b/quant/ with recipe → quantize → serialize → pack flow plus unit tests.
  • Adds Gemma 4 31B model implementation with hybrid attention and a sliding-window KV cache, plus export + eager inference entrypoints.
  • Adds CUDA runner build targets and runs Gemma 4 31B tests in the CUDA GitHub Actions workflow.

Reviewed changes

Copilot reviewed 28 out of 28 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
examples/models/gemma4_31b/test_pipeline.py CPU-only integration tests for quantize/save/load roundtrip and tiny checkpoint fixtures.
examples/models/gemma4_31b/test_cuda_pipeline.py CUDA integration tests for pack/infer/export on a tiny model.
examples/models/gemma4_31b/sampler.py GPU-side Gumbel-max sampler used by the exported model.
examples/models/gemma4_31b/quantize_and_save.py CLI to quantize HF checkpoints and write packing-agnostic safetensors bundles + production recipes.
examples/models/gemma4_31b/quant/test_serialize.py Unit tests for nibble packing and safetensors serialization format.
examples/models/gemma4_31b/quant/test_recipe.py Unit tests for regex/layer-filter recipe matching + production recipe regression tests.
examples/models/gemma4_31b/quant/test_quantize.py Unit tests for quantize_weight and quantize_model APIs (CPU + CUDA/HQQ paths).
examples/models/gemma4_31b/quant/test_pack_cuda.py CUDA unit tests for int4/int8 packers and load-and-pack dispatcher behavior.
examples/models/gemma4_31b/quant/serialize.py Canonical quantized weight format + safetensors save/load with versioned metadata.
examples/models/gemma4_31b/quant/recipe.py Declarative quantization recipe/rule objects with regex FQN matching and optional layer filters.
examples/models/gemma4_31b/quant/quantize.py Implements min-max and HQQ quantization into canonical (packing-free) representations.
examples/models/gemma4_31b/quant/pack_cuda.py CUDA-specific packers converting canonical weights into torchao runtime tensor subclasses.
examples/models/gemma4_31b/quant/pack.py Backend-agnostic pack dispatcher that assigns weights/buffers and calls module-type packers.
examples/models/gemma4_31b/quant/init.py Public API re-exports for quant/ package.
examples/models/gemma4_31b/quant/README.md Documentation of the quant framework, data flow, and backend extension points.
examples/models/gemma4_31b/model.py Gemma 4 31B model definition, HF checkpoint loader, ring KV cache for sliding layers, runtime buffer materialization.
examples/models/gemma4_31b/model.md Architecture/design notes for model + quant pipeline.
examples/models/gemma4_31b/main.cpp ExecuTorch CUDA runner driving exported prefill/decode and HF tokenizer decoding.
examples/models/gemma4_31b/inference.py Eager CUDA inference script loading prequantized weights, packing, and generating text.
examples/models/gemma4_31b/export.py Export + lowering pipeline (decode + prefill methods) targeting the CUDA backend.
examples/models/gemma4_31b/init.py Package marker for the new model example.
examples/models/gemma4_31b/README.md User-facing instructions for quantize/export/inference/build/run workflows.
examples/models/gemma4_31b/CMakePresets.json CMake preset for building the Gemma 4 31B CUDA runner.
examples/models/gemma4_31b/CMakeLists.txt CMake build for the Gemma 4 31B runner, linking ExecuTorch + CUDA backend + tokenizer.
examples/models/gemma4/text_decoder/gemma4_norm.py Replaces transformers RMSNorm dependency with a self-contained implementation.
examples/models/gemma4/text_decoder/init.py Exposes attention/norm/MLP primitives used by gemma4_31b for shared numerically-sensitive ops.
Makefile Adds gemma4_31b-cuda build target.
.github/workflows/cuda.yml Adds Gemma 4 31B quant + pipeline tests to the CUDA CI job.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread examples/models/gemma4_31b/model.py
Comment thread examples/models/gemma4_31b/model.py
Comment thread examples/models/gemma4_31b/quant/pack.py Outdated
Comment thread examples/models/gemma4_31b/test_pipeline.py Outdated
Comment thread examples/models/gemma4_31b/quant/recipe.py Outdated
Copilot AI review requested due to automatic review settings April 30, 2026 14:01
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a full Gemma 4 31B-IT example to ExecuTorch, including a new packing-agnostic quantization framework, CUDA packing/export/inference tooling, GGUF import support, a C++ CUDA runner, and a comprehensive test suite integrated into CI.

Changes:

  • Introduce examples/models/gemma4_31b/quant/ canonical quantization framework (recipe → quantize → serialize → pack) with CUDA packers and safetensors persistence.
  • Add Gemma 4 31B-IT model implementation with ring-buffer KV cache for sliding-window layers, plus export/eager inference/runner scripts.
  • Add unit + integration tests (CPU and CUDA) and run them in the CUDA CI workflow.

Reviewed changes

Copilot reviewed 31 out of 31 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
examples/models/gemma4_31b/tests/test_pipeline.py CPU-only integration tests for quantize→save→load roundtrip and fixtures for CUDA tests.
examples/models/gemma4_31b/tests/test_cuda_pipeline.py CUDA integration tests for packing, generation, chunked prefill, and export.
examples/models/gemma4_31b/sampler.py GPU-side Gumbel-max sampler used by the exported model for on-device sampling.
examples/models/gemma4_31b/quantize_and_save.py CLI to quantize HF checkpoints and persist packing-agnostic safetensors checkpoints.
examples/models/gemma4_31b/quant/tests/test_serialize.py Unit tests for canonical serialize/deserialize and nibble pack/unpack.
examples/models/gemma4_31b/quant/tests/test_recipe.py Unit tests for regex/layer-filter recipe matching + production recipe regression tests.
examples/models/gemma4_31b/quant/tests/test_quantize.py Unit tests for canonical quantize/dequantize APIs and model-walking quantization.
examples/models/gemma4_31b/quant/tests/test_pack_cuda.py CUDA unit tests for packing canonical weights into CUDA runtime formats and dispatch.
examples/models/gemma4_31b/quant/tests/test_gguf.py Unit tests validating GGUF Q4_K/Q6_K unpacking against reference formulas.
examples/models/gemma4_31b/quant/serialize.py Canonical CQW representation + safetensors format, nibble packing, save/load.
examples/models/gemma4_31b/quant/recipe.py Declarative quantization recipe/rule/config structures and matching logic.
examples/models/gemma4_31b/quant/quantize.py Canonical quantization implementations (min_max, HQQ) + per-model quantization walk.
examples/models/gemma4_31b/quant/pack_cuda.py CUDA packers from canonical weights to tinygemm/int8 subclass runtime formats.
examples/models/gemma4_31b/quant/pack.py Backend-agnostic pack dispatcher grouping weights per module and applying packers.
examples/models/gemma4_31b/quant/gguf.py GGUF tensor unpacker/streamer to canonical CQW or dense tensors.
examples/models/gemma4_31b/quant/init.py Public API re-exports for quant/ package.
examples/models/gemma4_31b/quant/README.md Documentation of the quant framework layers, data flow, and on-disk format.
examples/models/gemma4_31b/model.py Gemma 4 31B-IT model definition, ring-buffer KV cache, HF load/remap, runtime buffer materialization.
examples/models/gemma4_31b/model.md Architecture/design notes including attention flavors, caching strategy, and export methods.
examples/models/gemma4_31b/main.cpp CUDA ExecuTorch runner driving exported prefill/decode with tokenizer integration.
examples/models/gemma4_31b/inference.py Eager CUDA inference path (load/pack/materialize/compile + generate loop).
examples/models/gemma4_31b/gguf_loader.py Streams GGUF weights, remaps keys, handles tied weights, and packs into runtime model.
examples/models/gemma4_31b/export.py Export+lower pipeline producing .pte/.ptd with decode/prefill methods and CUDA lowering.
examples/models/gemma4_31b/init.py Package marker for the Gemma 4 31B example.
examples/models/gemma4_31b/README.md User-facing docs for quantize/export/inference/build/run flows.
examples/models/gemma4_31b/CMakePresets.json CMake workflow presets to build the CUDA runner.
examples/models/gemma4_31b/CMakeLists.txt CMake build for the Gemma 4 31B CUDA runner binary.
examples/models/gemma4/text_decoder/gemma4_norm.py Replace transformers RMSNorm dependency with a self-contained implementation.
examples/models/gemma4/text_decoder/init.py Export shared primitives (attention helpers, KV cache, MLP, norm) for reuse by gemma4_31b.
Makefile Add gemma4_31b-cuda build target and help entry.
.github/workflows/cuda.yml Run Gemma 4 31B quant and pipeline tests as part of CUDA unit test job.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread examples/models/gemma4_31b/quant/quantize.py Outdated
Comment thread examples/models/gemma4_31b/quant/tests/test_gguf.py Outdated
Comment thread .github/workflows/cuda.yml
Comment thread examples/models/gemma4_31b/tests/test_cuda_pipeline.py Outdated
Comment thread examples/models/gemma4_31b/tests/test_pipeline.py
Comment thread examples/models/gemma4_31b/tests/test_cuda_pipeline.py Outdated
@mergennachin
Copy link
Copy Markdown
Contributor Author

@claude Review this PR

@claude
Copy link
Copy Markdown

claude Bot commented Apr 30, 2026

Claude finished @mergennachin's task in 2m 54s —— View job


PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

  • Gather context and understand the PR structure
  • Review quantization framework (quant/)
  • Review model implementation (model.py)
  • Review export and inference scripts
  • Review GGUF loader
  • Review C++ runner and build system
  • Review tests
  • Review CI integration
  • Post comprehensive review feedback

Overall Assessment

This is a well-structured, substantial addition (~5300 lines) that introduces a full Gemma 4 31B-IT pipeline. The architecture is clean — the quantization framework separates concerns nicely (recipe → quantize → serialize → pack), and the code is well-organized for reuse with future models. The test coverage is good (78 tests across multiple files), and the CI integration is appropriate.

I've reviewed the full diff and have the following feedback:


Architecture & Design (Positive)

  1. Clean separation of concerns in quant/: recipe declares intent, quantize produces canonical form, serialize handles persistence, pack handles backend-specific conversion. This layering is excellent for extensibility.

  2. Ring-buffer KV cache (model.py:64-98) is a good memory optimization for sliding-window layers. The buf_size = window_size * 2 sizing combined with chunked prefill in the runner is sound.

  3. GGUF support with streaming iteration (iter_gguf_tensors) keeps peak memory low. The tied embed/lm_head handling (dequantize embedding to bf16 for gather, keep lm_head as INT4) is a pragmatic choice.

  4. The pack_one API (pack.py:80-109) enables streaming GGUF loading without buffering all weights before packing — nice design for 31B-scale models.


Issues

1. QuantConfig.bits documentation vs implementation mismatch (quant/recipe.py:21)

The comment says # 4, 8 but the Copilot reviewer noted that tests use bits=6 (via GGUF Q6_K) and the QuantConfig doesn't validate. In practice, _unpack_q6_k produces a CQW with bits=8 and method="gguf_q6_k" (since 6-bit values are widened to int8), so this is mostly a documentation issue. But it would be clearer to document the contract as "storage width: 4 or 8" since that's what the rest of the pipeline actually supports.

2. Sliding-window mask correctness for multi-token prefill (model.py:448-454)

The ring_pos calculation:

ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size

This computes the "logical position" for each ring-buffer slot. When input_pos[0] == 0 and seq_len == buf_size, total_written = buf_size, and ring_pos[j] = j for all j, which is correct (all slots were just written). The mask delta >= 0 ensures causality, and delta < window clips the sliding window.

However, I'd recommend adding a runtime assert or at least a more prominent comment that seq_len <= buf_size is a hard invariant. The C++ runner respects get_max_prefill_chunk, but nothing in the Python model's forward prevents violation. If someone calls model(tokens, input_pos) with seq_len > buf_size, the index_copy_ on line 96 has non-deterministic behavior (as the comment notes) but no error is raised. Consider:

# In RingKVCache.update:
assert input_pos.shape[0] <= self.buf_size, f"seq_len {input_pos.shape[0]} > buf_size {self.buf_size}"

This would catch misuse in eager mode (export would strip it).

3. C++ runner: read_token casts float to uint64 (main.cpp:60-86)

float val = 0.0f;
// ... read float from tensor ...
return static_cast<uint64_t>(val);

The sampler returns argmax(...).float(), so val is a token ID stored as float. The cast float → uint64_t truncates. For token IDs up to 262144 (Gemma 4 vocab), float32 has sufficient precision (exact up to 2^24 = 16M). This is fine in practice but worth noting — if vocab ever exceeds 16M, this would silently produce wrong IDs. A comment like "// safe: vocab_size < 2^24 so float→int is exact" would help.

4. _INT8_PER_AXIS group_size is hardcoded to hidden_size (quantize_and_save.py:51)

_INT8_PER_AXIS = QuantConfig(bits=8, group_size=5376, symmetric=True, method="min_max")

This hardcodes group_size=5376 (the hidden_size). If someone reuses this recipe with a model whose embedding dim differs, the quantization will produce wrong-shaped scales. Since the recipe is specifically for Gemma 4 31B this is acceptable, but a comment noting the coupling would be helpful.

5. GGUF test import handling (quant/tests/test_gguf.py:28-29)

if _HAS_GGUF:
    from executorch.examples.models.gemma4_31b.quant.gguf import unpack_gguf_tensor

The @unittest.skipUnless(_HAS_GGUF, ...) decorators handle skipping correctly. However, dequantize_weight and serialize/deserialize are imported unconditionally (line 30-31), which is fine since they don't depend on gguf. The CI workflow correctly installs gguf before running tests (line 152 of cuda.yml). This is properly handled.

6. Chunked prefill test uses allclose appropriately (tests/test_cuda_pipeline.py:129-138)

The test uses atol=1e-2, rtol=1e-3 which is reasonable for comparing across different CUDA kernel execution paths. Good fix from the earlier Copilot suggestion about torch.equal being too strict.


Minor Suggestions

  1. model.py:486-495 — The temperature is None branch returns full (B, T, V) logits, while the temperature branch returns (B, 1) token IDs. The return type depends on runtime input, which makes the type signature unclear. This is documented in the docstring, but a @torch.no_grad() decorator on forward would also be appropriate since this model is inference-only.

  2. inference.py:98-101 — Eager prefill is one-token-at-a-time. This is intentional (comment explains it's simpler for torch.compile), but for a 31B model with long prompts, this could be slow. The exported PTE path handles this properly via chunked prefill. Consider adding a note in the README about this trade-off.

  3. pack_cuda.py:164pack_int4_for_cuda is called with device="cuda" then immediately moved to CPU:

    module.weight = nn.Parameter(packed.data.to("cpu"), requires_grad=False)
    torch.cuda.empty_cache()

    This is because tinygemm packing requires CUDA, but the model is built on CPU for export. The torch.cuda.empty_cache() after each linear is a good practice for 31B-scale models to avoid OOM during the packing phase.

  4. gguf_loader.py:128config = Gemma4_31BConfig(max_seq_len=max_seq_len) uses all defaults. This is correct for the community GGUF files (which are the standard 31B architecture), but if GGUF metadata contains different hyperparameters, they'd be ignored. Fine for now since GGUF only stores the standard config.

  5. CMakeLists.txt:51-53FATAL_ERROR when EXECUTORCH_BUILD_CUDA is not set is appropriate since this runner is CUDA-only.


Test Coverage Assessment

The test suite is comprehensive:

  • test_recipe.py: Regex matching, layer filters, production recipe regression tests
  • test_quantize.py: Output structure, roundtrips, error cases, HQQ paths
  • test_serialize.py: Nibble packing, save/load, 3D tensors, metadata
  • test_pack_cuda.py: Module-type dispatch, matmul accuracy, mixed precision, MoE grouping
  • test_gguf.py: Q4_K/Q6_K unpacking vs reference formulas, serialize roundtrip
  • test_pipeline.py: Quantize→save→load integration
  • test_cuda_pipeline.py: Pack→generate, chunked prefill correctness, export e2e

Missing: No test for the RingKVCache in isolation (though it's exercised by test_chunked_prefill_matches_sequential).


Summary

This is a high-quality PR with clean architecture, good test coverage, and solid engineering choices for memory efficiency at 31B scale. The main actionable items are:

  1. Add a runtime assert in RingKVCache.update for seq_len <= buf_size
  2. Update the QuantConfig.bits comment to reflect actual supported values (storage widths 4 and 8)
  3. Comment the float→uint64 cast safety in main.cpp

None of these are blocking — the code is correct as written, and the invariants are upheld by the runner/export logic. These are defensive hardening suggestions.


Copilot AI review requested due to automatic review settings April 30, 2026 14:50
@mergennachin
Copy link
Copy Markdown
Contributor Author

@claude Review this PR

@claude
Copy link
Copy Markdown

claude Bot commented Apr 30, 2026

Claude finished @mergennachin's task in 2m 30s —— View job


PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

  • Gather context and understand the PR structure
  • Review quantization framework (quant/)
  • Review model implementation (model.py)
  • Review export and inference scripts
  • Review GGUF loader
  • Review C++ runner and build system
  • Review tests
  • Review CI integration
  • Post comprehensive review feedback

Overall Assessment

This is a high-quality, well-architected PR (~5500 lines) adding a full Gemma 4 31B-IT pipeline. The separation of concerns is clean — recipe declares intent, quantize produces canonical form, serialize persists, pack converts to backend format. Test coverage is comprehensive (78 tests), and the PR addresses the previous review's concerns well (added the RingKVCache assert, fixed symmetric INT4 dequant, added GGUF tests with proper skip handling).


Architecture & Design (Strengths)

  1. Layered quantization abstractionquant/ cleanly separates recipe → quantize → serialize → pack. Adding a new backend (e.g., Metal) requires only a new pack_metal.py; adding a new model requires only a QuantRecipe. This will pay dividends for Qwen 3.5 MoE and future models.

  2. Ring-buffer KV cache (model.py:64-101) — Sizing to window_size * 2 with the runtime assert on line 96 is a good defensive choice. The sliding mask logic in _build_masks correctly handles wraparound by computing ring_pos from total_written.

  3. Streaming GGUF loading (gguf_loader.py) — Processing one tensor at a time via iter_gguf_tensors + pack_one keeps peak memory low during loading of a 31B model. The untied embed/lm_head handling (dequantize embedding for gather, keep lm_head quantized for matmul) is pragmatic and well-documented.

  4. Chunked prefill in the C++ runner (main.cpp:242-287) — Correctly queries get_max_prefill_chunk metadata and respects the ring-buffer limit. The fallback to kMaxSeqLen - 1 when metadata is absent is safe.


Issues & Suggestions

1. read_token float→uint64 cast should use llrintf (Fixed — good)

main.cpp:85 uses llrintf(val) to round the float to the nearest integer before casting. This is the right approach — raw static_cast<uint64_t>(val) truncates toward zero, which could produce off-by-one errors for token IDs. Well done.

2. _INT8_PER_AXIS hardcodes group_size=5376 (quantize_and_save.py:51-53)

_INT8_PER_AXIS = QuantConfig(bits=8, group_size=5376, symmetric=True, method="min_max")

This couples the recipe to Gemma 4 31B's hidden_size. Since the recipe is named GEMMA4_31B_* and lives in the model-specific file, this is acceptable. But if this framework is promoted to be shared across models, consider deriving from the config:

# Future: group_size=config.hidden_size for per-axis quantization

Fix this →

3. Potential integer overflow in _build_masks ring_pos calculation (model.py:455)

ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size

When total_written is 0 (i.e., input_pos[0] == 0 and seq_len == 0), total_written - 1 underflows to -1 for signed int64 tensors. In practice this can't happen (forward is never called with empty input), but it's worth noting. The actual execution path always has seq_len >= 1, so total_written >= 1.

4. GGUF test skip handling is correct (quant/tests/test_gguf.py:21-28)

The earlier review flagged that GGUF tests might fail without the gguf package. The current code handles this correctly:

  • Lines 21-25: try/except guards _HAS_GGUF
  • Line 28: Conditional import of unpack_gguf_tensor only when _HAS_GGUF=True
  • Lines 30-31: dequantize_weight and serialize/deserialize are imported unconditionally (they don't need gguf)
  • All test classes use @unittest.skipUnless(_HAS_GGUF, ...) decorators
  • CI (cuda.yml:152) installs gguf before running tests

This is properly handled now.

5. _move_to_cuda preserves tensor subclass identity (inference.py:48-55)

for name, p in model.named_parameters():
    parts = name.rsplit(".", 1)
    parent = model.get_submodule(parts[0]) if len(parts) > 1 else model
    setattr(parent, parts[-1],
            torch.nn.Parameter(p.data.to("cuda"), requires_grad=False))

This iterates named_parameters() while mutating them via setattr. In CPython this is safe because named_parameters() yields from a snapshot of the module tree (it doesn't lazily walk), but it's fragile. More importantly, p.data.to("cuda") on a Int4TilePackedTo4dTensor subclass relies on the subclass implementing __torch_dispatch__ for the to op correctly. The docstring correctly notes this intent — just flagging that if torchao changes the subclass dispatch behavior, this could silently break. A more defensive approach would be to check type(p.data) after the move, but this is minor.

6. Export caps prefill at min(max_seq_len - 1, sliding_window * 2) (export.py:167)

max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2)

For the default config (max_seq_len=4096, sliding_window=1024), this gives max_prefill=2048. The full-attention layers use a flat Gemma4KVCache of size max_seq_len=4096, so they have no issue. The sliding layers use a ring buffer of size 2048, matching exactly. This is correct.

However, consider that Dim("seq_len", min=2, max=max_prefill) means the dynamic shape for prefill ranges from 2 to 2048. If a user passes max_seq_len=512 (< 2*sliding_window), max_prefill=511, and the ring buffer is still sized to 2048 but the dynamic dim is capped at 511. This is fine — it's a conservative limit.

7. Q4_K zero handling when scale=0 (quant/gguf.py:77-79)

zero_std = torch.where(
    eff_scale != 0, eff_min / eff_scale, torch.zeros_like(eff_min)
)

The test_zero_scale_produces_zero test validates this edge case — when scale=0, zero_std is forced to 0, avoiding NaN from division. The comment in the test references this as a fixed regression. Good.

8. Sampler numerical stability (sampler.py:40-41)

noise = torch.rand_like(logits)
gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20)

The double-log Gumbel trick with 1e-20 epsilon is standard. However, torch.rand_like can produce exactly 0.0 (uniform on [0, 1)), and -log(1e-20) ≈ 46, so -log(46 + 1e-20) ≈ -3.8. This is fine numerically — no NaN or Inf can occur given the epsilon. The temperature clamp at 1e-6 also prevents division issues.


Minor Observations

  1. model.py:486-498 — The forward return type changes based on temperature (returns (B, T, V) logits when None, (B, 1) token IDs otherwise). This is documented but means type annotations can't express the full contract. Fine for an inference-only model.

  2. inference.py:96-101 — One-token-at-a-time prefill is intentionally chosen for torch.compile friendliness (uniform shape = one graph). The docstring explains the tradeoff clearly. For a 31B model with 4096-token prompts this is slow (~4s on A100), but the PTE path with chunked multi-token prefill is the production path.

  3. pack_cuda.py:164 — Moving INT4 packed weights from CUDA back to CPU (packed.data.to("cpu")) with torch.cuda.empty_cache() is good memory hygiene for the packing phase. Each linear is packed individually on CUDA to avoid accumulating GPU memory for all 60 layers simultaneously.

  4. quant/tests/test_pack_cuda.py — I didn't read this fully but its presence (482 lines) indicates thorough CUDA packer testing including matmul accuracy validation.


Test Coverage Assessment

The test suite is comprehensive and well-organized:

File Scope Key tests
test_recipe.py Regex matching, layer filters, production recipe regression 14 test cases
test_quantize.py quantize_weight/quantize_model API contract, HQQ paths 12 test cases
test_serialize.py Nibble packing, save/load, 3D tensors, versioned headers ~15 test cases
test_pack_cuda.py Module-type dispatch, matmul accuracy, mixed precision, MoE ~20 test cases
test_gguf.py Q4_K/Q6_K unpacking vs reference formulas, serialize roundtrip 7 test cases
test_pipeline.py Quantize→save→load integration, RingKVCache isolation 7 test cases
test_cuda_pipeline.py Pack→generate, chunked prefill correctness, export e2e 4 test cases

Good coverage. The RingKVCache unit tests (test_pipeline.py:212-266) are a welcome addition addressing the earlier review's gap.


Summary

This is a well-executed, production-quality addition. The architecture decisions are sound, the code is clean and readable (comments only where non-obvious), and the test coverage is thorough. The main items from previous reviews have been addressed. No blocking issues remain.

Actionable items (non-blocking):

  1. Comment the group_size=5376 coupling to hidden_size in quantize_and_save.py:51
  2. Consider adding @torch.no_grad() decorator to Gemma4_31B.forward since this is inference-only

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a full Gemma 4 31B-IT text-only pipeline to the ExecuTorch examples, including a reusable packing-agnostic quantization framework (recipe/quantize/serialize/pack), GGUF import, CUDA packing/export/inference flows, and a CUDA runner, with CI coverage.

Changes:

  • Introduces examples/models/gemma4_31b/ model implementation (ring-buffer KV cache), export/inference scripts, GGUF loader, and C++ CUDA runner + build targets.
  • Adds a new quant/ framework (recipes, min-max + HQQ quantization, safetensors format, CUDA packing, GGUF Q4_K/Q6_K unpack).
  • Adds unit/integration tests and wires them into the CUDA GitHub Actions workflow.

Reviewed changes

Copilot reviewed 31 out of 31 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
examples/models/gemma4_31b/tests/test_pipeline.py CPU-only pipeline + RingKVCache tests for quant/save/load and cache behavior
examples/models/gemma4_31b/tests/test_cuda_pipeline.py CUDA integration tests for pack/infer/export + chunked prefill equivalence
examples/models/gemma4_31b/sampler.py GPU-side Gumbel-max sampler (mirrors Qwen sampler behavior)
examples/models/gemma4_31b/quantize_and_save.py CLI to quantize HF checkpoint and save canonical safetensors checkpoint
examples/models/gemma4_31b/quant/tests/test_serialize.py Unit tests for canonical format + nibble packing + safetensors I/O
examples/models/gemma4_31b/quant/tests/test_recipe.py Unit tests for regex/layer-filter recipe matching + production recipe regression
examples/models/gemma4_31b/quant/tests/test_quantize.py Unit tests for min-max + HQQ quantize/dequantize and quantize_model behavior
examples/models/gemma4_31b/quant/tests/test_pack_cuda.py CUDA unit tests for packers (int4 tinygemm, int8 intx, dispatch/grouping)
examples/models/gemma4_31b/quant/tests/test_gguf.py Unit tests for GGUF Q4_K/Q6_K unpacking and serialize roundtrip
examples/models/gemma4_31b/quant/serialize.py CanonicalQuantizedWeight + serialize/deserialize + safetensors save/load
examples/models/gemma4_31b/quant/recipe.py QuantConfig/QuantRule/QuantRecipe declarative matching logic
examples/models/gemma4_31b/quant/quantize.py min-max + HQQ quantize_weight/dequantize_weight + per-model quantization
examples/models/gemma4_31b/quant/pack_cuda.py CUDA packers for Linear/Embedding and load+pack convenience wrapper
examples/models/gemma4_31b/quant/pack.py Backend-agnostic pack_model/pack_one dispatch + grouping by parent module
examples/models/gemma4_31b/quant/gguf.py GGUF tensor unpack + streaming iterator to canonical representation
examples/models/gemma4_31b/quant/init.py Public API exports for quant framework and CUDA packers
examples/models/gemma4_31b/quant/README.md Framework overview, dataflow, and backend/model extension guidance
examples/models/gemma4_31b/model.py Gemma4 31B model, ring-buffer KV cache, HF loader, runtime buffer materialization
examples/models/gemma4_31b/model.md Architecture/design notes + export/quantization details
examples/models/gemma4_31b/main.cpp CUDA runner driving exported prefill/decode methods with chunked prefill
examples/models/gemma4_31b/inference.py Eager CUDA inference with torch.compile + GGUF/prequantized load paths
examples/models/gemma4_31b/gguf_loader.py Streams GGUF tensors, remaps keys, handles tied embed/lm_head, packs model
examples/models/gemma4_31b/export.py Export+lower (decode+prefill) to ExecuTorch with CUDA backend and metadata
examples/models/gemma4_31b/init.py Package marker for gemma4_31b example
examples/models/gemma4_31b/README.md End-to-end usage docs for quantize/export/infer/build/run
examples/models/gemma4_31b/CMakePresets.json CMake workflow presets for building the runner (CUDA preset)
examples/models/gemma4_31b/CMakeLists.txt Runner build configuration + CUDA backend linkage
Makefile Adds gemma4_31b-cuda build target
.github/workflows/cuda.yml Runs gemma4_31b quant + pipeline tests in CUDA CI
examples/models/gemma4/text_decoder/gemma4_norm.py Replaces transformers RMSNorm import with self-contained implementation
examples/models/gemma4/text_decoder/init.py Exposes additional shared primitives used by gemma4_31b model

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread examples/models/gemma4_31b/main.cpp
Comment thread examples/models/gemma4_31b/quant/recipe.py
Comment thread examples/models/gemma4_31b/quant/serialize.py Outdated
Comment thread examples/models/gemma4_31b/quant/quantize.py
Comment thread examples/models/gemma4_31b/quant/quantize.py Outdated
@mergennachin
Copy link
Copy Markdown
Contributor Author

@claude Review this PR

@claude
Copy link
Copy Markdown

claude Bot commented Apr 30, 2026

Claude finished @mergennachin's task in 2m 30s —— View job


PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

  • Gather context and understand the PR structure
  • Review model implementation (model.py)
  • Review quantization framework (quant/)
  • Review GGUF loader and streaming support
  • Review export and inference scripts
  • Review C++ runner
  • Review tests and CI integration
  • Post comprehensive review feedback

Overall Assessment

This is a well-structured, high-quality addition (~5600 lines across 31 files) introducing a full Gemma 4 31B-IT text-only pipeline. The architecture is clean — the quantization framework separates concerns into orthogonal layers (recipe → quantize → serialize → pack), the model implementation is faithful to the HF reference, and the streaming GGUF support keeps peak memory tractable. Test coverage is comprehensive (78 tests), and the PR has clearly evolved through multiple review iterations addressing earlier feedback.

The codebase is ready for merge with minor non-blocking suggestions below.


Architecture Strengths

  1. Layered quantization frameworkquant/ cleanly separates concerns: recipe.py (what to quantize), quantize.py (how to quantize), serialize.py (persistence format), pack.py (runtime dispatch), pack_cuda.py (backend-specific). Adding a new backend (Metal, XLA) requires only a new pack_*.py; adding a model requires only a recipe.

  2. Ring-buffer KV cache (model.py:64-101) — Sizing to 2 × sliding_window instead of max_seq_len saves ~75% of KV memory for sliding layers. The runtime assert on line 95-97 guards against oversized prefill — good defensive programming.

  3. Streaming I/O everywhereiter_load (serialize.py:238-281), iter_gguf_tensors (gguf.py:201-219), and load_and_pack_for_cuda (pack_cuda.py:200-224) all process one weight at a time. At 31B scale, this is the difference between 17GB peak and 60GB+ peak during model loading.

  4. Shared primitives — The numerically sensitive ops (RMSNorm, RoPE, KV cache, MLP) are shared with the E2B/E4B variant via examples.models.gemma4.text_decoder, avoiding drift.

  5. Chunked prefill in the C++ runner (main.cpp:242-293) — Correctly queries get_max_prefill_chunk metadata and chunks accordingly, with cudaDeviceSynchronize (line 308) between prefill and decode to flush shared KV-cache writes.


Issues & Suggestions

1. _build_masks ring_pos computation when total_written == 0 (model.py:453-455)

total_written = input_pos[0] + seq_len
j = torch.arange(buf_size, dtype=torch.long, device=input_pos.device)
ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size

When input_pos[0] == 0 and seq_len >= 1, total_written >= 1, so total_written - 1 >= 0 and the floor division is well-defined. The invariant seq_len >= 1 is guaranteed by the export shapes (decode T=1, prefill T≥2). However, consider that total_written - 1 - j can be negative for j > total_written - 1 (i.e., slots that haven't been written yet). In that case (negative // buf_size) * buf_size yields a negative ring_pos, and the mask clause ring_pos >= 0 (line 457) correctly excludes those slots. This is sound.

No action needed — just documenting my verification.

2. QuantConfig.group_size comment is stale (recipe.py:24)

group_size: int  # 32, 64, 128

The actual usage includes group_size=5376 (per-axis embedding), group_size=16 (Q6_K), and group_size=32 (Q4_K, standard). The comment should say "any positive divisor of the last dim" or similar.

Fix this →

3. quantize_model() prints to stdout unconditionally (quantize.py:241-242)

print(f"  Quantized {i + 1}/{n_params}: {fqn}", end="\r")

Since this is a reusable library function called by tests, this pollutes test output. Consider gating behind a verbose parameter (default True for backward compat) or using logging.

4. _INT8_PER_AXIS hardcodes group_size=5376 (quantize_and_save.py:51-53)

_INT8_PER_AXIS = QuantConfig(  # group_size = hidden_size (5376) for Gemma 4 31B
    bits=8, group_size=5376, symmetric=True, method="min_max"
)

The comment already explains the coupling, which is good. When promoting this framework for Qwen 3.5 MoE, this will need to be parameterized. Fine for now.

5. C++ runner: missing <cstring> and <cmath> includes (main.cpp:62-88)

read_token uses memcpy (needs <cstring>) and llrintf (needs <cmath>). Currently these compile because they're transitively included via gflags/CUDA headers, but this is fragile.

Fix this →

6. deserialize() doesn't reshape loaded tensors (serialize.py:190-206)

deserialize() records original_shape from metadata but never reshapes qdata to match for 8-bit weights (4-bit gets unpacked to the right shape via _nibble_unpack). The CanonicalQuantizedWeight.__post_init__ validates element counts but not shapes. A corrupted file with transposed dimensions would pass validation but produce wrong results silently. Consider adding an explicit reshape:

qdata = qdata.reshape(original_shape)

after the nibble-unpack branch.

7. _move_to_cuda mutates parameters during iteration (inference.py:48-55)

for name, p in model.named_parameters():
    parts = name.rsplit(".", 1)
    parent = model.get_submodule(parts[0]) if len(parts) > 1 else model
    setattr(parent, parts[-1],
            torch.nn.Parameter(p.data.to("cuda"), requires_grad=False))

named_parameters() yields from a generator that walks the module tree. Mutating modules during iteration is safe in CPython (the generator captures module references upfront), but it's implementation-dependent. A safer pattern is list(model.named_parameters()). This is minor — the current code works correctly.

8. Sampler's temperature is None return type divergence (sampler.py:35-41)

if temperature is None:
    return logits  # (B, V)
...
return (logits + gumbel).argmax(dim=-1, keepdim=True).float()  # (B, 1)

The return type depends on the runtime value of temperature. This is intentional (documented in model.py:476-479) and works because the two code paths are exported as separate methods (prefill returns logits for the last chunk, decode uses temperature). No action needed, but a type annotation like -> torch.Tensor (current) is the best we can do here.


Test Coverage

The test suite is well-organized and thorough:

File Tests Key coverage
quant/tests/test_recipe.py 14 Regex matching, layer filters, production recipe regression
quant/tests/test_quantize.py 12 quantize_weight/dequantize_weight roundtrip, HQQ paths, error cases
quant/tests/test_serialize.py ~15 Nibble pack/unpack, save/load, 3D tensors, format versioning
quant/tests/test_pack_cuda.py ~20 Module dispatch, matmul accuracy, mixed precision, MoE grouping
quant/tests/test_gguf.py 7 Q4_K/Q6_K vs reference formulas, serialize roundtrip, edge cases
tests/test_pipeline.py 7 Quantize→save→load, RingKVCache isolation, corrupted checkpoint
tests/test_cuda_pipeline.py 4 Pack→generate, chunked prefill correctness, export e2e

The GGUF test skip handling (test_gguf.py:21-28) is properly implemented — @unittest.skipUnless(_HAS_GGUF, ...) on each class, conditional import of unpack_gguf_tensor, and CI installs gguf before running (cuda.yml:152).

The RingKVCache unit tests (test_pipeline.py:212-266) cover sequential write, wraparound, multi-token, and oversized-prefill assertion — good targeted coverage for the most subtle piece of logic.


Summary

This is production-quality work with sound architecture decisions for memory efficiency at 31B scale. The main actionable items are:

  1. Add <cstring> and <cmath> includes in main.cpp (fragile transitive dependency)
  2. Update group_size comment in recipe.py to reflect actual supported values
  3. Consider reshaping qdata in deserialize() for 8-bit weights (defensive hardening)

None of these are blocking — the code is correct as written and the invariants are upheld by the export/runner logic. Previous review iterations (symmetric INT4 dequant fix, RingKVCache assert, GGUF skip handling, allclose tolerances) have all been addressed properly.




@dataclass
class CanonicalQuantizedWeight:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is subclass support for this "canonical" format that is later converted to specialized formats.

The basic flow is:

  • On your device-specific subclass, define a class method constructor "from_{canonical_format}_tensor", e.g., Int4PreshuffledTensor.from_int4_tensor(tensor)
  • Load "canonical" tensor subclass checkpoint
  • Iterate through canonical tensors, applying the conversion above.

IIRC, this is the path used by torchao in vLLM (cc @jerryzh168 to confirm).

CanonicalQuantizedWeight feels like a new abstraction that we'd maintain to do the same thing?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@metascroy Good point.

I'll use Int4Tensor and IntxUnpackedToInt8Tensor as they seem like canonical/portable format that can serialized to/from safetensors.

Will delete CanonicalQuantizedWeight

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@metascroy Made the change. Thank for the recommendation

Copilot AI review requested due to automatic review settings May 1, 2026 16:50
@mergennachin mergennachin force-pushed the gemma4-31b-quant-framework branch from 49cfc55 to 69474e3 Compare May 1, 2026 16:50
@mergennachin
Copy link
Copy Markdown
Contributor Author

@claude Review this PR

@claude
Copy link
Copy Markdown

claude Bot commented May 1, 2026

Claude encountered an error —— View job


PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

  • Gather context and understand recent changes
  • Review quantization framework (quant/) — especially torchao subclass migration
  • Review model implementation (model.py)
  • Review export, inference, and GGUF loader
  • Review C++ runner and build system
  • Review tests and CI integration
  • Post comprehensive review feedback

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new ExecuTorch example for Gemma 4 31B-IT with a reusable INT4/INT8 weight-only quantization + serialization + packing pipeline (torchao tensor subclasses + safetensors), along with export/inference tooling, a CUDA runner, and extensive test coverage.

Changes:

  • Introduce examples/models/gemma4_31b/ model implementation, export/inference scripts, GGUF loader, and CUDA runner (CMake + Make target).
  • Add quant/ framework (recipe → quantize → pack) using torchao tensor subclasses and torchao safetensors flatten/unflatten.
  • Add CPU/CUDA pipeline tests + quant unit tests; extend CUDA CI workflow to execute them.

Reviewed changes

Copilot reviewed 30 out of 30 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
examples/models/gemma4_31b/tests/test_pipeline.py CPU-only integration tests for quantize/save/load + unit tests for RingKVCache and GGUF key mapping.
examples/models/gemma4_31b/tests/test_cuda_pipeline.py CUDA integration tests for generate, chunked prefill correctness, and export paths.
examples/models/gemma4_31b/sampler.py GPU-side Gumbel-max sampler used to keep exported programs single-output across temperatures.
examples/models/gemma4_31b/quantize_and_save.py CLI to quantize HF checkpoints on CPU and save torchao-subclass safetensors checkpoints.
examples/models/gemma4_31b/quant/tests/test_safetensors_roundtrip.py Smoke tests for safetensors roundtrip of torchao tensor subclasses.
examples/models/gemma4_31b/quant/tests/test_recipe.py Unit tests for QuantRecipe matching + regression tests for production recipes.
examples/models/gemma4_31b/quant/tests/test_quantize.py Unit tests for quantize/dequantize behavior and error paths (incl. HQQ).
examples/models/gemma4_31b/quant/tests/test_pack_cuda.py CUDA unit tests for packing INT4/INT8 weights and model-level pack/load paths.
examples/models/gemma4_31b/quant/tests/test_gguf.py Tests for GGUF Q4_K/Q6_K unpacking correctness + safetensors roundtrip.
examples/models/gemma4_31b/quant/recipe.py Declarative quantization recipe types (QuantConfig/Rule/Recipe) with regex + layer filtering.
examples/models/gemma4_31b/quant/quantize.py Quantize/dequantize and model-walk quantization producing torchao tensor subclasses.
examples/models/gemma4_31b/quant/pack_cuda.py CUDA packers for nn.Linear / nn.Embedding (INT4 tinygemm + INT8 intx pass-through).
examples/models/gemma4_31b/quant/pack.py Backend-agnostic dispatch for packing state dicts into meta-built runtime models.
examples/models/gemma4_31b/quant/gguf.py GGUF tensor unpacking (Q4_K/Q6_K/F16/F32) into torchao subclasses with streaming iterator.
examples/models/gemma4_31b/quant/init.py Re-exports for quant package public API.
examples/models/gemma4_31b/quant/README.md Documentation of the quant framework dataflow and on-disk format (torchao safetensors).
examples/models/gemma4_31b/model.py Export-friendly Gemma 4 31B model with hybrid attention + ring-buffer KV cache + sampling.
examples/models/gemma4_31b/model.md Architecture/design notes covering attention variants, export methods, quantization, runtime buffers.
examples/models/gemma4_31b/main.cpp CUDA runner driving exported prefill/decode with chunking and BOS/EOS handling.
examples/models/gemma4_31b/inference.py Eager CUDA inference path (optionally torch.compile) for prequantized or GGUF-loaded models.
examples/models/gemma4_31b/gguf_loader.py Streams GGUF tensors, remaps keys, unties embed/lm_head behavior, and packs for backend.
examples/models/gemma4_31b/export.py Export+lower pipeline producing shared-buffer prefill/decode methods + metadata constants.
examples/models/gemma4_31b/init.py Package marker for the new example.
examples/models/gemma4_31b/README.md End-user instructions and recommended workflows (quantize → export/infer).
examples/models/gemma4_31b/CMakePresets.json Presets to build the Gemma 4 31B CUDA runner.
examples/models/gemma4_31b/CMakeLists.txt CMake target for runner; enforces CUDA build and links required ExecuTorch extensions.
examples/models/gemma4/text_decoder/gemma4_norm.py Removes transformers dependency by re-implementing Gemma4 RMSNorm in-tree.
examples/models/gemma4/text_decoder/init.py Exposes additional gemma4 text_decoder primitives (attention helpers, norms, MLP).
Makefile Adds gemma4_31b-cuda build target and help entry.
.github/workflows/cuda.yml Runs Gemma 4 31B quant + pipeline tests in CUDA CI (installs gguf).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread examples/models/gemma4_31b/README.md Outdated
Comment thread examples/models/gemma4_31b/quant/quantize.py
Comment thread examples/models/gemma4_31b/quant/pack_cuda.py Outdated
@mergennachin mergennachin force-pushed the gemma4-31b-quant-framework branch from 69474e3 to 551f3b0 Compare May 1, 2026 17:27
Copilot AI review requested due to automatic review settings May 1, 2026 21:02
@mergennachin mergennachin force-pushed the gemma4-31b-quant-framework branch from 551f3b0 to 2604159 Compare May 1, 2026 21:02
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 31 out of 31 changed files in this pull request and generated 3 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread examples/models/gemma4_31b/main.cpp
Comment thread examples/models/gemma4_31b/quant/README.md
Comment thread examples/models/gemma4_31b/quant/serialize.py Outdated
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 30 out of 30 changed files in this pull request and generated 4 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread examples/models/gemma4_31b/quant/pack_cuda.py Outdated
Comment thread examples/models/gemma4_31b/quant/pack_cuda.py Outdated
Comment thread examples/models/gemma4_31b/quant/pack_cuda.py Outdated
Comment thread examples/models/gemma4_31b/main.cpp
Copilot AI review requested due to automatic review settings May 4, 2026 20:33
@mergennachin mergennachin force-pushed the gemma4-31b-quant-framework branch from b26fb75 to d488474 Compare May 4, 2026 20:33
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 35 out of 35 changed files in this pull request and generated 2 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread backends/cuda/transforms/int4_linear_dispatch.py Outdated
Comment on lines +28 to +47
#include <cinttypes>
#include <fstream>
#include <string>
#include <vector>

#include <executorch/runtime/platform/platform.h>
#include <executorch/runtime/platform/types.h>
extern "C" void et_pal_emit_log_message(
ET_UNUSED et_timestamp_t timestamp,
et_pal_log_level_t level,
const char* filename,
ET_UNUSED const char* function,
size_t line,
const char* message,
ET_UNUSED size_t length) {
if (level < 'W') {
return;
}
fprintf(stderr, "%c [%s:%zu] %s\n", (char)level, filename, line, message);
}
GGUF names (e.g., ``blk.0.attn_q.weight``); the caller handles key
remapping. GGUF shapes are reversed to PyTorch convention automatically.
"""
from gguf import GGUFReader
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious do we need gguf support even for enablement?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd like to do apples-to-apples latency comparison on a same checkpoint

…uTorch

Text-only export of Gemma 4 31B-IT to ExecuTorch with the CUDA backend
and INT4/INT8 weight quantization via a new packing-agnostic quant/
framework.

The quant/ package separates quantization into four concerns:
  - recipe.py: declarative QuantRecipe with regex FQN matching
  - quantize.py: produces CanonicalQuantizedWeight (min_max, HQQ)
  - serialize.py: save/load to safetensors with versioned headers
  - pack.py + pack_cuda.py: per-module packer dispatch for CUDA

Two production recipes: "default" (INT4 min_max + INT8 embedding) and
"sensitive" (INT8 for edge-layer v_proj/down_proj, INT4 HQQ elsewhere).

Sliding window attention uses a ring-buffer KV cache (2x window size)
for the 50 sliding layers, saving memory for long sequences. The 10
full-attention layers use a standard flat KV cache.

Includes C++ runner (main.cpp), eager inference script, and 60+ unit
and integration tests across quant/ and pipeline test files.
- Sliding window layers use RingKVCache (2×window) instead of flat
  max_seq_len buffer, reducing KV cache memory for long sequences.
- Prefill is capped to ring buffer size; the C++ runner chunks longer
  prompts automatically via get_max_prefill_chunk metadata.
- Both recipes now quantize embed_tokens to INT8 per-axis (~1.4 GB
  savings vs bf16). Embedding packer uses IntxUnpackedToInt8Tensor
  which supports gather.
- pack_model handles top-level FQNs (no parent module).
- C++ runner aligned with Qwen patterns: #ifdef guards for non-CUDA
  builds, better weight_sharing error handling, cudaDeviceSynchronize
  between prefill and decode.
- Test suite split into test_pipeline.py (CPU) and test_cuda_pipeline.py
  (CUDA) with shared fixtures. New chunked prefill correctness test.
- Prequantized checkpoint available at
  huggingface.co/SocialLocalMobile/gemma-4-31B-it-HQQ-INT4.
- Added Gemma 4 31B tests to cuda.yml CI workflow.
- Cleaned up stale terminology, docstrings, and comments throughout.
- quant/gguf.py: unpack Q4_K/Q6_K GGUF blocks to CanonicalQuantizedWeight,
  with iter_gguf_tensors for streaming (low peak memory). Validated
  against original bf16 weights (Q4_K: 7.9%, Q6_K: 1.9% error).
- gguf_loader.py: Gemma 4 31B GGUF key mapping + load_gguf_model.
  Handles tied embed/lm_head: embedding dequantized to bf16 (gather),
  lm_head keeps Q4_K (tinygemm matmul).
- export.py and inference.py: --gguf flag for direct GGUF file loading.
- quant/quantize.py: dequantize_weight (inverse of quantize_weight).
- quant/pack.py: pack_one for single-weight streaming; pack_model
  delegates to pack_one for unquantized, groups quantized by parent
  for multi-weight modules (MoE-compatible).
- quant/serialize.py: CanonicalQuantizedWeight.__post_init__ validation
  (dtype, shape, symmetric/zero consistency).
- Tests moved to tests/ folders (quant/tests/ and tests/).
- dequantize_weight now subtracts 8 from symmetric 4-bit qdata (stored
  as unsigned [0,15]) before scaling, matching the quantize_weight shift
- Guard test_gguf.py with skipUnless so CI doesn't break without gguf
- Install gguf in cuda.yml for GGUF test coverage
- Use torch.allclose instead of torch.equal for chunked prefill logit
  comparison to avoid CUDA FP flakiness
- Fix Usage docblock paths in test_pipeline.py and test_cuda_pipeline.py
- Fix float→uint64 truncation in main.cpp read_token (use llrintf)
- Add assert in RingKVCache.update to catch seq_len > buf_size misuse
- Add RingKVCache unit tests (sequential, wraparound, multi-token, assert)
- Add CanonicalQuantizedWeight __post_init__ validation error path tests
- Add GGUF Q4_K through tinygemm pack pipeline test (asymmetric)
- Add 8-bit asymmetric matmul test
- Add F16 GGUF tensor type test
- Document QuantConfig.bits as storage width and _INT8_PER_AXIS coupling
- serialize.py: add iter_load() generator that streams weights one at a
  time from safetensors, keeping peak memory proportional to the largest
  single weight instead of loading all weights into memory at once.
- pack_cuda.py: rewrite load_and_pack_for_cuda to use iter_load for
  streaming — avoids ~40 GB peak memory when loading the 31B checkpoint.
- __init__.py: remove low-level CUDA packer internals (pack_int4_for_cuda,
  pack_int8_for_cuda, pack_linear_for_cuda, pack_embedding_for_cuda) from
  the public API. Tests import these directly from pack_cuda.py.
Gemma's HuggingFace tokenizer does not auto-prepend BOS. Without it
the model's logits collapse. Add --bos_id (default 2) to prepend and
--eos_id (default 1) as a fallback stop token.
Delete the custom CanonicalQuantizedWeight dataclass and serialize.py
format. Quantized weights are now stored as torchao's native Int4Tensor
(4-bit) and IntxUnpackedToInt8Tensor (8-bit) subclasses, serialized via
torchao's safetensors integration.

Key changes:
- quantize_weight returns Int4Tensor or IntxUnpackedToInt8Tensor
- quantize_model returns a single state_dict (not two dicts)
- 8-bit quantization done in float32 to avoid bf16 precision loss
  (manual quantize + direct IntxUnpackedToInt8Tensor construction)
- Sensitive recipe uses HQQ asymmetric INT4 (scale + zero optimization)
- pack_model takes a single state_dict, dispatches by isinstance
- pack.py uses TorchAOBaseTensor for quantized weight detection
- GGUF unpacker produces Int4Tensor/IntxUnpackedToInt8Tensor directly
- serialize.py dissolved — callers inline torchao safetensors directly

Breaking change: existing prequantized checkpoints (old format) must
be regenerated with quantize_and_save.py.
- Use .detach() instead of .data when moving packed INT4 weight to CPU
  to preserve tensor subclass identity safely
- Remove unused loaded_keys set in load_and_pack_for_cuda
- Handle top-level tensor keys (no dot) in load_and_pack_for_cuda
Extend ReplaceEdgeOpWithTritonOpPass to select triton::sdpa_decode_splitk
for SDPA nodes where L_q=1 (decode) and L_kv exceeds 2048 (large KV
cache). This dramatically improves GPU utilization for full-attention
layers at long context lengths — standard SDPA launches only a handful
of CTAs (proportional to H_kv), while split-K partitions the KV sequence
across up to 128 CTAs.

Benchmarked on A100 with Gemma4 31B shapes at 128K context:
  Full-attention decode (H_kv=4, D=512, L_kv=131072):
    standard SDPA: 15.7ms/layer → split-K: 0.7ms/layer (22x)
  Sliding-attention decode (H_kv=16, D=256, L_kv=2048):
    unchanged (standard SDPA is faster for small L_kv)

The threshold of 2048 is chosen to match the sliding-window ring buffer
size — anything above is a full-attention cache where split-K wins.

No changes to model code — the pass inspects Q/K shapes in the exported
graph and selects the kernel automatically.
Change pack_cuda.py to store INT4 weights as IntxUnpackedToInt8Tensor
(dequant+cuBLAS) instead of Int4TilePackedTo4dTensor (tinygemm). Add
use_tinygemm_linears source transform for decode optimization.

The export flow now exports prefill first (default dequant+cuBLAS,
optimal for large M) then applies the tinygemm transform and exports
decode (optimal for M=1).

Prefill speedup: 12x vs tinygemm at T=2048 (2.6ms vs 32ms per linear).
Decode unchanged (tinygemm, 68us per linear at M=1).

pack_cuda.py no longer requires CUDA for packing. The tinygemm
conversion moves to a model-agnostic source transform in
backends/cuda/transforms/int4_linear_dispatch.py.
@mergennachin mergennachin force-pushed the gemma4-31b-quant-framework branch from d488474 to 5b54f50 Compare May 5, 2026 16:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants