Skip to content

Commit 1c371e2

Browse files
gasoonjiaGasoonjia
authored andcommitted
[executorch][cuda] gemma4_31b: fuse gate/up MLP projections (default-on)
Summary: Fuse each gemma4_31b MLP's gate_proj|up_proj into a single [2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA export. This issues one activation-quant + one W4A8 matvec per layer instead of two, cutting per-token launch + activation-quant overhead in the launch-bound decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct). Builds on the already-landed kv_len-bounded tq4_sdpa kernel + gemma4_31b call-site (kv_len + mask_is_causal), which recovered 128k decode from ~2.8 to ~43 tok/s. With both, ET gemma4_31b 128k+TurboQuant decode beats llama.cpp at every measured context (cuda_graph ON): ctx ET llama 512 44.80 42.77 2K 43.20 41.97 8K 42.23 41.23 32K 41.64 40.27 127K 38.41 35.97 TurboQuant KV compression kept; prefill restored (6-8x) with no regression; output quality preserved. Test Plan: - Fusion numerics: fused vs unfused MLP through the real W4A8 int4_plain_mm kernel = bit-exact (max_abs_diff 0.0, cos 1.000000) for decode (T=1) and prefill (T=4). - Export + run: fused module exported via CudaPartitioner and executed through executor_runner (RC=0, cos 0.999915 vs eager). Full 31B export logs "Fused gate+up on 60 MLP layers". - Decode A/B (gemma4_31b 128k+TQ, cuda_graph ON, 5x median): table above; beats llama.cpp at 512 -> 127K. nsys: tq4_sdpa 91.7% -> 2.9% of decode.
1 parent 993cff5 commit 1c371e2

2 files changed

Lines changed: 111 additions & 5 deletions

File tree

examples/models/gemma4_31b/cuda_source_transformations.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import torch
3232
import torch.nn as nn
33+
import torch.nn.functional as F
3334

3435
from executorch.examples.models.gemma4.text_decoder import apply_rotary_emb
3536
from executorch.extension.llm.modules.turboquant import TurboQuantKVCache
@@ -110,13 +111,117 @@ def _turboquant_attention_forward(
110111
return self.o_proj(y)
111112

112113

114+
def _fused_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
115+
"""Drop-in ``Gemma4MLP.forward`` over a fused gate|up projection.
116+
117+
Identical math to ``down(gelu(gate(x)) * up(x))``: the single
118+
``gate_up_proj`` emits ``[gate | up]`` concatenated on the last dim,
119+
which is then split. One W4A8 matmul (and one activation-quant of ``x``)
120+
instead of two.
121+
"""
122+
h = self.gate_up_proj(x)
123+
gate = h[..., : self.intermediate_size]
124+
up = h[..., self.intermediate_size :]
125+
return self.down_proj(F.gelu(gate, approximate="tanh") * up)
126+
127+
128+
def _concat_coalesced_int4_along_n(a, b):
129+
"""Concatenate two ``CudaCoalescedInt4Tensor`` along the output (N) dim.
130+
131+
qdata is ``[N, K/2]`` and scale/zero_point are ``[N, n_groups]`` in the
132+
coalesced layout, so a per-output-row concat on dim 0 is exact: the W4A8
133+
dp4a matvec reads each output row's qdata/scale/zero independently, so
134+
out[:N_a] reproduces ``a`` and out[N_a:] reproduces ``b`` bit-for-bit.
135+
"""
136+
from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor
137+
138+
return CudaCoalescedInt4Tensor(
139+
torch.cat([a.qdata, b.qdata], dim=0),
140+
torch.cat([a.scale, b.scale], dim=0),
141+
torch.cat([a.zero_point, b.zero_point], dim=0),
142+
a.block_size,
143+
torch.Size([a.shape[0] + b.shape[0], a.shape[1]]),
144+
None,
145+
a.activation_dtype,
146+
)
147+
148+
149+
def _is_fuseable_int4_pair(gate_w, up_w) -> bool:
150+
"""True iff gate/up are both coalesced-int4 with matching K + block_size.
151+
152+
Q4_K MLP weights become ``CudaCoalescedInt4Tensor`` (fuseable); a Q6_K
153+
weight becomes ``CudaDp4aPlanarInt6Tensor`` (left alone). ``act_pre_scale``
154+
is unused on this path but we require it absent so the concat stays exact.
155+
"""
156+
from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor
157+
158+
return (
159+
isinstance(gate_w, CudaCoalescedInt4Tensor)
160+
and isinstance(up_w, CudaCoalescedInt4Tensor)
161+
and list(gate_w.block_size) == list(up_w.block_size)
162+
and gate_w.shape[1] == up_w.shape[1]
163+
and gate_w.act_pre_scale is None
164+
and up_w.act_pre_scale is None
165+
)
166+
167+
168+
def _fuse_gate_up_proj(model: nn.Module) -> None:
169+
"""Fuse each MLP's ``gate_proj | up_proj`` into one ``gate_up_proj``.
170+
171+
gate and up share the same input, so the unfused path quantizes ``x`` to
172+
int8 twice and launches two W4A8 matvecs per layer. Fusing the weights
173+
into one ``[2*inter, hidden]`` tensor halves both. Weight bytes read are
174+
unchanged, so the win is launch + activation-quant overhead (decode is
175+
launch-bound). Only Q4_K (coalesced-int4) layers are fused; any layer
176+
with a non-int4 weight is left as two matmuls (still correct).
177+
178+
Must run AFTER weights are packed to ``CudaCoalescedInt4Tensor`` (i.e.
179+
inside ``_export_cuda``), and is independent of TurboQuant.
180+
"""
181+
n_fused = 0
182+
n_skipped = 0
183+
for layer in model.layers:
184+
mlp = getattr(layer, "mlp", None)
185+
if mlp is None or not (hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj")):
186+
continue
187+
gate_w = mlp.gate_proj.weight
188+
up_w = mlp.up_proj.weight
189+
if not _is_fuseable_int4_pair(gate_w, up_w):
190+
n_skipped += 1
191+
continue
192+
inter = up_w.shape[0]
193+
hidden = up_w.shape[1]
194+
fused_w = _concat_coalesced_int4_along_n(gate_w, up_w)
195+
196+
# Container built on meta to avoid materializing a dense
197+
# [2*inter, hidden] weight before we overwrite it with fused_w.
198+
gate_up = nn.Linear(hidden, 2 * inter, bias=False, device="meta")
199+
gate_up.weight = nn.Parameter(fused_w, requires_grad=False)
200+
mlp.gate_up_proj = gate_up
201+
mlp.intermediate_size = inter
202+
del mlp.gate_proj
203+
del mlp.up_proj
204+
mlp.forward = types.MethodType(_fused_mlp_forward, mlp)
205+
n_fused += 1
206+
207+
msg = f"[gemma4_31b cuda] Fused gate+up on {n_fused} MLP layers"
208+
if n_skipped:
209+
msg += f" ({n_skipped} skipped: non-int4 weights)"
210+
print(msg)
211+
212+
113213
def cuda_source_transformations(
114214
model: nn.Module,
115215
*,
116216
use_turboquant: bool = False,
117217
) -> None:
118218
"""Apply CUDA source transformations to a Gemma 4 31B model in place.
119219
220+
Always fuses each MLP's ``gate_proj|up_proj`` into a single matmul (one
221+
activation-quant + one W4A8 matvec per layer instead of two; Q4_K
222+
coalesced-int4 layers only — other quant types are left untouched).
223+
Optionally also swaps full-attention KV caches for TurboQuant TQ4.
224+
120225
Args:
121226
model: ``Gemma4_31B`` instance to transform.
122227
use_turboquant: When True, swap full-attention layers' KV caches
@@ -125,6 +230,8 @@ def cuda_source_transformations(
125230
``torch.ops.triton.tq4_sdpa``. Sliding-window layers are
126231
unaffected.
127232
"""
233+
_fuse_gate_up_proj(model)
234+
128235
if not use_turboquant:
129236
return
130237

examples/models/gemma4_31b/export.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,12 +182,11 @@ def _export_cuda(
182182

183183
materialize_runtime_buffers(model, dtype=torch.bfloat16)
184184

185-
if use_turboquant:
186-
from executorch.examples.models.gemma4_31b.cuda_source_transformations import (
187-
cuda_source_transformations,
188-
)
185+
from executorch.examples.models.gemma4_31b.cuda_source_transformations import (
186+
cuda_source_transformations,
187+
)
189188

190-
cuda_source_transformations(model, use_turboquant=True)
189+
cuda_source_transformations(model, use_turboquant=use_turboquant)
191190

192191
# Int4Tensor weights are used directly — no format conversion.
193192
# F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim).

0 commit comments

Comments
 (0)