perf(K2.5): optimize small kernels in EAGEL3 drafter loop#142
Merged
Conversation
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
fc15f9a to
80fd19b
Compare
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Per-call overhead reductions in the EAGLE3 drafter loop and surrounding metadata prep.
Changes
compute_out_cache_locuniform variant — addedcompute_out_cache_loc_uniformfor the drafter's multi-step decode where every request has
input_length=1. Skipsthe per-call
torch.cumsum+torch.fullhost-side work and the kernel's GMEMreads of
input_lengths_ptr/cumsum_lengths_ptr(Triton specializes onNone-pointer at JIT time).
req_pool_indices_bufis now int64 — eliminates ~12 implicit int32→int64unrolled_elementwisecast kernels (~1.6 µs each) along the per-iterationmetadata prep path. Pairs with switching the
valid_cache_lengths[idx]fancyindex to
valid_cache_lengths.index_select(0, idx)so int32 indices are stillaccepted natively where they remain (e.g. in
index_add_).draft_seq_lens_buf,draft_out_cache_loc_buf,draft_input_lengths_buf, andlast_index_offsets_buf(=arange(max_bs) * spec_num_tokens - 1) are hoisted toEagle.__init__to avoid per-call alloc +init.
last_index_offsetsis now plumbed viaForwardContexttoLogitsProcessorfor the padded-static-len last-token-select path.Eagle.draft()cleanup — fusedcache_start + 1intotorch.add(..., out=draft_seq_lens); replaced the post-draft()torch.catwith direct writes into a shared
next_tokens[bs, spec_num_steps+1]buffer;skip the last-iter
positions.add_(1)/draft_seq_lens.add_(1)since they'renot consumed; removed the dead
logits.shape[0] != bsfallback branch.draft_seq_lens_buf,draft_out_cache_loc_buf,draft_input_lengths_buf, andlast_index_offsets_bufare hoisted toEagle.__init__to avoid per-call alloc + init. The last one(
= torch.arange(max_bs) * spec_num_tokens - 1, int64) replaces twoper-call
torch.arange(bs, ...) * spec_num_tokenspatterns — one in thedrafter's last-verified-id selection, one in
LogitsProcessor's padded-static-lenlast-token-select. The precomputed buffer is sliced to
[:bs]and plumbed viaForwardContext.last_index_offsetssoLogitsProcessorcan skip thearange + mul + subtriplet. Net: pre-step-0 last-token-select drops from6 kernels (arange + mul + sub + 2 gathers) to 3 (1 add + 2 gathers).