[gemma4_31b][cuda] Export Gemma4-31B @128k on 5090 #20480
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,28 +60,54 @@ def _cuda(self, qdata, scale, zero, group_size): | |
| return _dequant_matmul(self, qdata, scale, zero, group_size) | ||
|
|
||
|
|
||
| # Chunked dequant for the export GPU budget. The lm_head dequant (N = vocab_size, | ||
| # e.g. 262144) runs through the int4_plain_mm custom op (M=1); AOTI executes that | ||
| # op's CUDA impl during autotune / cpp_wrapper codegen, where it transiently holds | ||
|
Contributor
Is this just a crude way of doing tile-level dequant?
Contributor
Author
Yes, indeed. It is tile-level dequant.
||
| # ~5 full-size bf16 temporaries (low/high/data/data-z/w_deq) — ~10 GiB for a | ||
| # 262144-row weight even though the final w_deq is only ~2.6 GiB. Chunking along N | ||
| # caps that at ~chunk rows. It is numerically identical (F.linear output rows are | ||
| # independent), and because only the lm_head (custom-op) path crosses the N | ||
| # threshold — never the M>4 prefill inline path — it never enters the runtime | ||
| # graph: ZERO runtime / accuracy impact. Applied unconditionally to any weight | ||
| # whose row count exceeds the threshold. | ||
| _DEQUANT_N_THRESHOLD = 65536 | ||
| _DEQUANT_N_CHUNK = 32768 | ||
|
Comment on lines +73 to +74
Contributor
Aren't these kind of device-specific?
Contributor
Author
No, these are just the parameters that control the peak memory used for dequant.
||
|
|
||
|
|
||
| def _dequant_matmul(x, qdata, scale, zero, group_size): | ||
| """Dequant INT4 weights to input dtype and call F.linear. | ||
|
|
||
| scale/zero are in the coalesced [N, n_groups] layout (baked into the | ||
| weight constant at pack time), aligned row-for-row with qdata's [N, *]. | ||
|
|
||
| Large weights (N > threshold, i.e. the lm_head) are chunked along N to bound | ||
| the dequant intermediate (see note above); smaller weights take the original | ||
| single-shot dequant. | ||
| """ | ||
| N, K_half = qdata.shape | ||
| K = K_half * 2 | ||
| n_groups = K // group_size | ||
| gs_half = group_size // 2 | ||
| dtype = x.dtype | ||
|
|
||
| p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half) | ||
| low = (p & 0x0F).to(dtype) | ||
| high = ((p >> 4) & 0x0F).to(dtype) | ||
| data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size) | ||
|
|
||
| s = scale.to(dtype).unsqueeze(-1) | ||
| z = zero.to(dtype).unsqueeze(-1) | ||
| w_deq = ((data - z) * s).reshape(N, K) | ||
|
|
||
| return F.linear(x, w_deq) | ||
| def _dq(qd, sc, ze, rows): | ||
| p = qd.to(torch.uint8).reshape(rows, n_groups, gs_half) | ||
| low = (p & 0x0F).to(dtype) | ||
| high = ((p >> 4) & 0x0F).to(dtype) | ||
| data = torch.stack([low, high], dim=-1).reshape(rows, n_groups, group_size) | ||
| s = sc.to(dtype).unsqueeze(-1) | ||
| z = ze.to(dtype).unsqueeze(-1) | ||
| w_deq = ((data - z) * s).reshape(rows, K) | ||
| return F.linear(x, w_deq) | ||
|
|
||
| if N <= _DEQUANT_N_THRESHOLD: | ||
| return _dq(qdata, scale, zero, N) | ||
|
|
||
| outs = [] | ||
| for i in range(0, N, _DEQUANT_N_CHUNK): | ||
| j = min(i + _DEQUANT_N_CHUNK, N) | ||
| outs.append(_dq(qdata[i:j], scale[i:j], zero[i:j], j - i)) | ||
| return torch.cat(outs, dim=-1) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
|
|
||
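The chunking in the diff above can be checked on toy data. The following is a minimal sketch, not the PR's code: the function names `dequant_linear` and `chunked_dequant_linear`, the toy sizes, and the random inputs are all assumptions made for illustration. It runs on CPU in float32 and compares the two paths with a tolerance, which is a conservative choice here because matmul kernels for different shapes are not guaranteed to reduce in the same order.

```python
import torch
import torch.nn.functional as F


def dequant_linear(x, qdata, scale, zero, group_size):
    """Single-shot path: unpack two INT4 values per byte, dequantize, F.linear."""
    n, k_half = qdata.shape
    k = k_half * 2
    n_groups = k // group_size
    p = qdata.to(torch.uint8).reshape(n, n_groups, group_size // 2)
    low = (p & 0x0F).to(x.dtype)          # low nibble of each byte
    high = ((p >> 4) & 0x0F).to(x.dtype)  # high nibble of each byte
    data = torch.stack([low, high], dim=-1).reshape(n, n_groups, group_size)
    s = scale.to(x.dtype).unsqueeze(-1)
    z = zero.to(x.dtype).unsqueeze(-1)
    w_deq = ((data - z) * s).reshape(n, k)
    return F.linear(x, w_deq)


def chunked_dequant_linear(x, qdata, scale, zero, group_size, chunk):
    """Chunked path: dequantize `chunk` rows at a time, then concatenate outputs."""
    outs = []
    for i in range(0, qdata.shape[0], chunk):
        j = i + chunk  # slicing clamps j to N, so the last chunk may be shorter
        outs.append(dequant_linear(x, qdata[i:j], scale[i:j], zero[i:j], group_size))
    # Each F.linear call yields a disjoint slice of output features.
    return torch.cat(outs, dim=-1)


torch.manual_seed(0)
N, K, GROUP = 40, 64, 32  # toy sizes; N is deliberately not a multiple of the chunk
x = torch.randn(1, K)     # M = 1, as in the lm_head decode case
qdata = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8)
scale = torch.rand(N, K // GROUP) + 0.5
zero = torch.randint(0, 16, (N, K // GROUP)).float()

full = dequant_linear(x, qdata, scale, zero, GROUP)
chunked = chunked_dequant_linear(x, qdata, scale, zero, GROUP, chunk=16)
print(full.shape, torch.allclose(full, chunked, atol=1e-3))
```

Because each output feature depends only on its own weight row, splitting along N changes how much dequantized weight exists at once but not which values are computed.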
I wish there were a better way to do this, i.e. why does this logic need to be aware of export issues?
Well, it is not an export issue, but it does affect memory consumption during export, which is reasonable.
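To make the memory argument concrete, here is a rough back-of-the-envelope sketch. It uses only the approximate figures quoted in the code comment (about 10 GiB of transient temporaries and about 2.6 GiB for the final `w_deq` at 262144 rows) and assumes the footprint scales linearly with the number of rows being dequantized. The helper name `scaled` is hypothetical, and the results are estimates rather than measurements.

```python
VOCAB_ROWS = 262144        # N for the lm_head, from the code comment
CHUNK_ROWS = 32768         # _DEQUANT_N_CHUNK in the diff
FULL_W_DEQ_GIB = 2.6       # approximate size of the full w_deq, from the comment
FULL_TRANSIENT_GIB = 10.0  # approximate transient peak, from the comment


def scaled(full_gib, rows, full_rows=VOCAB_ROWS):
    """Scale a full-size footprint to `rows`, assuming linear growth in rows."""
    return full_gib * rows / full_rows


n_chunks = -(-VOCAB_ROWS // CHUNK_ROWS)  # ceiling division
chunk_w_deq = scaled(FULL_W_DEQ_GIB, CHUNK_ROWS)
chunk_transient = scaled(FULL_TRANSIENT_GIB, CHUNK_ROWS)

print(f"chunks: {n_chunks}")
print(f"per-chunk w_deq: ~{chunk_w_deq:.3f} GiB")
print(f"per-chunk transient peak: ~{chunk_transient:.2f} GiB")
```

Under these assumptions each chunk is one eighth of the vocabulary, so the transient peak drops from roughly 10 GiB to roughly 1.25 GiB. The concatenated outputs add little on top of that when M=1, since each chunk contributes only one row of output values.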