[WIP] feat(lora): LoRA adapter serving#83
Draft
qywu wants to merge 43 commits into
Adds the foundational types and API surface for PEFT-style LoRA adapter
serving, unblocking the full runtime implementation.
New files:
python/tokenspeed/runtime/lora/lora_config.py — LoraConfig dataclass;
loads from PEFT adapter_config.json; exposes r, lora_alpha, scaling.
python/tokenspeed/runtime/lora/lora_registry.py — LoraRegistry tracks
loaded adapters, maps names to stable integer IDs, enforces max_loras
capacity (pinned adapters bypass the limit).
python/tokenspeed/runtime/lora/__init__.py
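The registry behaviour described above (stable integer IDs, a max_loras cap that pinned adapters bypass, and the lora_alpha / r scaling) can be sketched roughly as follows. This is a minimal illustration, not the PR's LoraRegistry: the class name, method names, and the choice to reserve ID 0 for the base model are assumptions made for the example.

```python
from dataclasses import dataclass, field


@dataclass
class LoraRegistrySketch:
    """Hypothetical stand-in for LoraRegistry; not the PR's actual class."""

    max_loras: int
    _ids: dict = field(default_factory=dict)    # adapter name -> integer ID
    _pinned: set = field(default_factory=set)   # names exempt from the cap
    _next_id: int = 1                           # assumption: 0 means base model

    def register(self, name: str, pinned: bool = False) -> int:
        if name in self._ids:
            return self._ids[name]              # IDs stay stable on repeat calls
        unpinned = len(self._ids) - len(self._pinned)
        if not pinned and unpinned >= self.max_loras:
            raise RuntimeError(f"max_loras={self.max_loras} reached")
        self._ids[name] = self._next_id
        self._next_id += 1
        if pinned:
            self._pinned.add(name)
        return self._ids[name]

    def unregister(self, name: str) -> None:
        self._ids.pop(name, None)
        self._pinned.discard(name)


def lora_scaling(lora_alpha: float, r: int) -> float:
    # Scaling as stated in this PR: lora_alpha / r.
    return lora_alpha / r
```

In this sketch a pinned adapter never counts against the cap, which is one possible reading of "pinned adapters bypass the limit".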
API additions:
GenerateReqInput.lora_path — per-request adapter selector (name or path).
ServerArgs: --enable-lora, --max-loras, --max-lora-rank.
EngineBase.load_lora_adapter() / unload_lora_adapter() — abstract API
with NotImplementedError stubs; full implementation tracked in PR #2.
Tests:
test/runtime/lora/test_lora_registry.py — 11 unit tests covering
registration, capacity enforcement, pinning, unregister, scaling.
TODO (tracked in PR):
- LoraManager: weight loading from safetensors into pre-allocated GPU
buffers (one buffer per target module × max_lora_rank).
- Request routing: resolve lora_path → lora_id in scheduler.
- Batched LoRA matmuls (sgmv / punica kernels or torch fallback).
- Engine.load/unload implementations calling LoraManager.
- OpenAI API: expose lora_path in /v1/completions and /v1/chat/completions.
- C++ scheduler: pass lora_id on requests for prefix-cache namespacing.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
…heduler
Implements the correct LoRA prefix cache namespace so:
• Same adapter + same tokens → cache hit ✓
• Different adapters + same tokens → no cross-adapter hit ✓
Design: per-adapter virtual root node
For each lora_id > 0, KVPrefixCache::getOrCreateLoraRoot() creates a child
of the real root keyed by a one-page sentinel token [-lora_id, 0, ..., 0].
Real vocabularies contain only non-negative token IDs, so the negative
sentinel cannot collide with real tokens, other adapters, or the base-model
namespace.
An empty DeviceResource is attached to the virtual root so:
• OnDevice() == true → PruneEmptyByNode never removes it
• IsLeaf() == false → eviction never tries to evict it
KVPrefixCache::Match() and Insert() accept a lora_id parameter (default 0)
and call resolveStartNode() to obtain the correct namespace root.
MatchResult::Device::namespace_depth_offset (new field, default 0) is set
to 1 for LoRA requests and subtracted inside DepthInPage() so all callers
see the number of real matched token pages, not including the sentinel page.
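The namespacing idea can be illustrated with a toy page-keyed trie in Python. This is only a sketch under assumptions: PAGE_SIZE, the class names, and the method names are invented for the example, and the real KVPrefixCache is C++ with device resources attached. Here depth is counted from the virtual root, which has the same effect as subtracting namespace_depth_offset.

```python
PAGE_SIZE = 4  # assumed page size, chosen only for illustration


class Node:
    def __init__(self):
        self.children = {}  # page (tuple of token IDs) -> Node


class PrefixCacheSketch:
    def __init__(self):
        self.root = Node()

    def _start_node(self, lora_id, create):
        if lora_id == 0:
            return self.root                      # base-model namespace
        # One-page sentinel [-lora_id, 0, ..., 0] keys the virtual root.
        sentinel = (-lora_id,) + (0,) * (PAGE_SIZE - 1)
        if create:
            return self.root.children.setdefault(sentinel, Node())
        return self.root.children.get(sentinel)

    @staticmethod
    def _pages(tokens):
        full = len(tokens) - len(tokens) % PAGE_SIZE  # drop the partial tail
        return [tuple(tokens[i:i + PAGE_SIZE]) for i in range(0, full, PAGE_SIZE)]

    def insert(self, tokens, lora_id=0):
        node = self._start_node(lora_id, create=True)
        for page in self._pages(tokens):
            node = node.children.setdefault(page, Node())

    def match(self, tokens, lora_id=0):
        """Return the number of real token pages matched (sentinel excluded)."""
        node = self._start_node(lora_id, create=False)
        if node is None:
            return 0
        depth = 0
        for page in self._pages(tokens):
            node = node.children.get(page)
            if node is None:
                break
            depth += 1
        return depth
```

Inserting the same tokens under lora_id=1 and then matching with lora_id=2 or with the default 0 finds nothing, which mirrors the two bullet points at the top of this commit.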
Changes:
request_spec.h — add lora_id: int32_t = 0
request.h / request.cpp — store + expose LoraId()
kv_prefix_cache.h/cpp — getOrCreateLoraRoot, resolveStartNode,
lora_id param on Match + Insert
types.h / types.cpp — namespace_depth_offset in MatchResult
forward_events.h/cpp — FinishEvent carries lora_id_, passes to Insert/Match
forward.cpp — pass request->LoraId() to all Match calls
outside_event_handler.cpp — pass req->LoraId() to FinishEvent
python_module.cpp — expose lora_id on Python RequestSpec
Tests (test_lora_prefix_cache.cpp, 6 cases):
SameAdapterReusesPrefixCache
DifferentAdaptersDontShareCache
BaseModelIndependentOfAdapters
MultipleAdaptersCacheIndependently
InsertLastNodeIsInAdapterNamespace
EvictionDoesNotCrossNamespaces
All 120 C++ tests pass.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Three paths were missing lora_id, causing cross-adapter KV cache collisions
when the hybrid (Mamba / HiCache) prefix cache is enabled:
1. HybridPrefixCache::Match(): added lora_id param, passes through to
KVPrefixCache::Match() so the per-adapter virtual root is used for L2
host-cache matching as well as device matching.
2. InsertHybridCache(): added lora_id param, passes through to
KVPrefixCache::Insert() so chunked-prefill inserts land in the correct
adapter namespace (previously always defaulted to kLoraNone).
3. SchedulePrefillEvent / ScheduleDecodeEvent: added lora_id_ field;
forward.cpp passes request->LoraId() at construction time. Both events
call InsertHybridCache() and now supply the adapter id.
Also fixes the schedulePrefillFirstChunk hybrid-path Match call, which was
passing lora_id only on the non-hybrid branch.
All 120 C++ tests pass.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
…lication
Implements the weight management layer for LoRA adapter serving.
LoraManager (python/tokenspeed/runtime/lora/lora_manager.py)
Pre-allocates a fixed GPU buffer with max_loras+1 slots (slot 0 = base model).
load_adapter(name, path): loads PEFT safetensors to CPU, computes scaling
from adapter_config.json (lora_alpha / r).
unload_adapter(name): zeroes the GPU slot and frees CPU cache.
prepare_loras(lora_ids): copies active adapters into GPU slots on demand,
returns weight_indices [bs] and scalings [n_slots]; evicts LRU non-pinned
adapters when the pool is full.
apply_qkv_lora / apply_o_lora: bmm-based delta application, TP-aware
(column-parallel projections shard B; row-parallel o_proj shards A and
all_reduces the partial output).
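The on-demand slot assignment that prepare_loras performs can be sketched as a small LRU pool. This is a hedged illustration only: SlotPoolSketch and its fields are hypothetical, weights are not actually copied, and the rule that adapters used by the current batch are never evicted is an assumption added so the example is self-consistent.

```python
from collections import OrderedDict


class SlotPoolSketch:
    """Hypothetical GPU slot pool: slot 0 is the base model, 1..max_loras hold adapters."""

    def __init__(self, max_loras):
        self.free = list(range(max_loras, 0, -1))  # pop() hands out slot 1 first
        self.resident = OrderedDict()              # lora_id -> slot, oldest first
        self.pinned = set()

    def _evict_lru(self, in_use):
        for lora_id in self.resident:              # least recently used first
            if lora_id not in self.pinned and lora_id not in in_use:
                return self.resident.pop(lora_id)  # reuse the victim's slot
        raise RuntimeError("no evictable adapter slot")

    def prepare(self, lora_ids):
        """Map per-request lora_ids to slot indices, loading or evicting as needed."""
        in_use = set(lora_ids)
        weight_indices = []
        for lora_id in lora_ids:
            if lora_id == 0:
                weight_indices.append(0)           # base model: no delta
                continue
            if lora_id not in self.resident:
                slot = self.free.pop() if self.free else self._evict_lru(in_use)
                self.resident[lora_id] = slot      # a real manager copies weights here
            self.resident.move_to_end(lora_id)     # mark as most recently used
            weight_indices.append(self.resident[lora_id])
        return weight_indices
```

The returned list corresponds to the per-request weight_indices [bs]; the scalings [n_slots] vector is omitted from this sketch.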
Model integration (qwen3.py)
Qwen3Attention.forward injects LoRA delta after qkv_proj and o_proj when
ctx.lora_manager is set. layer_id stored on Qwen3Attention.
Context / executor (context.py, model_executor.py)
ForwardContext gains lora_weight_indices, lora_scalings, lora_manager.
ModelExecutor.execute_forward_op injects LoRA info into ForwardContext when
any request in the batch carries a non-zero lora_id.
End-to-end routing
TokenizedGenerateReqInput.lora_id — integer resolved at tokenize time
from GenerateReqInput.lora_path via InputProcessor._resolve_lora_id().
make_spec / RequestSpec.lora_id — scheduler receives per-request adapter id.
EventLoop: init_lora_manager(), load_lora_adapter(), unload_lora_adapter();
_request_lora_ids dict tracks rid→lora_id for active requests.
RequestHandler: LoadLoraReqInput / UnloadLoraReqInput dispatch via callbacks.
scheduler_control_client: load_lora_communicator / unload_lora_communicator
+ async load/unload methods on AsyncLLM.
Engine.load_lora_adapter / unload_lora_adapter: delegate to tokenizer_manager.
Tested
PEFT reference on GPU 2: adapter_0 (argon) produces the memorized password
(Kx7#mP2$-VORTEX93qR-alpha!Z ≈ expected Kx7#mP2$-VORTEX-93qR-alpha!Z).
tokenspeed serve --enable-lora starts cleanly on GPU 4,5 and serves requests.
Base model correctly ignores adapters when lora_path is not set.
TODO (PR #2)
- Route lora_path from OpenAI /v1/completions HTTP body through to lora_id.
- Full integration test driving greedy output parity with PEFT.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Three fixes needed to run in eager mode (enforce_eager=True and
disable_pdl=True, both of which are auto-set when --enable-lora is used):
1. server_args: auto-set disable_pdl=True when enable_lora is set.
The TVM-JIT rmsnorm_cute kernel used by the PDL path is JIT-compiled
on first call with a fixed dtype; in eager mode the dtype may differ from
the CUDA-graph warmup call, causing a Mismatched Tensor error.
2. lora_manager: cast scale to the delta tensor's dtype before multiplying.
bfloat16_delta * float32_scale promoted the result to float32, which the
rope kernel cannot handle (DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16 failure).
Fix: (delta * scale.to(delta.dtype)).
3. qwen3.py: replace _apply_qk_norm kernel calls with a pure-PyTorch
RMSNorm implementation (_rms_norm static method). The flashinfer
rmsnorm_cute kernel is JIT-compiled and its cached dtype cannot be
changed at runtime; a simple x / rms * weight path avoids the kernel
entirely and works with any dtype.
Also adds benchmark/test_lora_dynamic.py — end-to-end test demonstrating
dynamic load/unload of two adapters while the engine is live. Confirmed:
- load_lora_adapter() / unload_lora_adapter() work at runtime
- LoRA weights ARE applied (different token IDs at generation position 7+
vs base model: base→ "The password is", argon adapter → "1789...")
- Prefix cache namespacing correct (different slots, isolated)
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
…completions
Exposes lora_path in the OpenAI-compatible HTTP API so clients can select
a LoRA adapter per request without any server restart.
protocol.py
- CompletionRequest.lora_path: str | None = None
- ChatCompletionRequest.lora_path: str | None = None
serving_completions.py / serving_chat.py
- Pass request.lora_path to GenerateReqInput so it flows through
InputProcessor._resolve_lora_id() → lora_id → scheduler routing.
Usage example:
curl http://localhost:8000/v1/completions \
-d '{"model":"Qwen/Qwen3-8B","prompt":"...", "lora_path":"argon","max_tokens":30}'
model_executor.py
- Fix per-token weight_indices expansion for mixed-adapter batches:
repeat_interleave(w_idx, input_lengths) so every token in a prefill
batch gets its request's correct adapter slot index, not just the
first N requests' indices sliced to total_tokens.
lora_manager.py
- Remove the broken per-token expansion from apply_qkv_lora/apply_o_lora;
weight_indices is now always already per-token when it arrives.
Single-request broadcast (1→tokens) is preserved.
benchmark/test_lora_batch.py
- New test: load argon + bastion, verify each produces different token
IDs from base model and from each other (adapter isolation proof).
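The per-token weight_indices fix in model_executor.py can be illustrated without tensors. The sketch below is a pure-Python equivalent of repeat_interleave(w_idx, input_lengths); sliced_expansion is a hypothetical stand-in for a slice-to-total_tokens expansion and is not the code that was removed.

```python
from itertools import chain, repeat


def expand_per_token(weight_indices, input_lengths):
    """Pure-Python equivalent of repeat_interleave(w_idx, input_lengths)."""
    return list(chain.from_iterable(
        repeat(idx, n) for idx, n in zip(weight_indices, input_lengths)
    ))


def sliced_expansion(weight_indices, input_lengths):
    """Illustrative broken variant: per-request indices sliced to total_tokens."""
    total_tokens = sum(input_lengths)
    return weight_indices[:total_tokens]


# Three requests using slots 1, 0 (base), 2 with prompt lengths 2, 3, 1.
per_token = expand_per_token([1, 0, 2], [2, 3, 1])
```

With these inputs the correct expansion has one slot index for each of the six tokens, whereas the sliced variant still has only three entries, so tokens would be paired with the wrong request's adapter.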
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Three correctness/cleanliness fixes to the virtual-root-per-adapter design:
1. Add KVPrefixCache::EvictLoraNamespace(lora_id): DFS-collects all
descendant nodes, calls ResourceManager::EvictSubtree() to detach
device/host pages (RAII auto-returns them to the allocator), then removes
the virtual root via RemoveChild (unique_ptr cascade destroys the subtree
including any mamba slots). Exposed as Scheduler::EvictLoraNamespace and
bound to Python as scheduler.evict_lora_namespace(lora_id). Called from
event_loop.unload_lora_adapter() so pages are freed immediately on unload
rather than waiting for LRU pressure.
2. Remove device_.UpdateLeaves(raw) from getOrCreateLoraRoot: the call was
a no-op (IsLeaf returns false for the empty-resource virtual root, and
updateLeaf(real_root) returns immediately on IsRoot check).
3. Add EvictLoraNamespaceFreesPagesImmediately and
EvictLoraNamespaceIdempotent tests.
All 122 C++ tests pass.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
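The shape of EvictLoraNamespace (collect the subtree, release its pages, drop the virtual root, and do nothing on a second call) can be sketched in Python. The TrieNode type, the pages counter, and the function name are hypothetical; the real implementation is C++ and relies on RAII and unique_ptr ownership rather than explicit counters.

```python
class TrieNode:
    def __init__(self, pages=0):
        self.children = {}
        self.pages = pages  # number of KV pages this node holds (illustrative)


def evict_namespace(root, sentinel):
    """Detach one adapter's subtree and return how many pages were freed.

    A missing namespace frees nothing, so repeated calls are idempotent.
    """
    virtual_root = root.children.get(sentinel)
    if virtual_root is None:
        return 0
    freed = 0
    stack = [virtual_root]               # iterative DFS over all descendants
    while stack:
        node = stack.pop()
        freed += node.pages
        node.pages = 0
        stack.extend(node.children.values())
    del root.children[sentinel]          # dropping the root releases the subtree
    return freed
```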
Replace the per-token bmm LoRA path with sglang/Punica-style segmented
Triton kernels (sgemm_lora_a / sgemm_lora_b / qkv_lora_b) and refactor
LoraManager around a persistent LoraBatchInfo so the captured CUDA graph
can replay against stable buffer pointers.
* Move LoraManager creation into ModelExecutor.__init__ so graphs are
captured with the LoRA path baked in (slot 0 = no-adapter, zero-delta via
rank-0 short-circuit in the kernels).
* Bind ctx.lora_manager during _capture_one and pre-fill batch_info with
one segment per "request" so all LoRA kernels are recorded.
* qwen3 attention now calls apply_qkv_lora / apply_o_lora with just
(hidden, qkv, layer_id); the manager owns batch_info.
* Drop the auto-disable of cuda graphs when --enable-lora is set.
* Single-GPU Qwen3-8B (TP=1, bs=1, 256 decode tokens, H100):
eager+LoRA 36.7 → graph+LoRA 105.5 tok/s (2.87x).
Also threads lora_path through Engine.generate so the in-process Engine API
matches the HTTP routing that already lands lora_path in GenerateReqInput.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Commit 126164b reintroduced a manual fp32 RMSNorm in ``_apply_qk_norm`` to
dodge a JIT-dtype mismatch in the rmsnorm_cute (PDL) kernel under
``--enable-lora``. Server args already auto-set ``disable_pdl=True`` for
that path, so the regular flashinfer ``rmsnorm`` (used by input_layernorm /
post_attention_layernorm) is correct here too. Restoring the fused kernel
collapses ~7 small launches per call into one.
Single-GPU Qwen3-8B (TP=1, bs=1, 256 decode tokens, H100):
* eager + base: 47.7 → 57.4 tok/s (+20%)
* graph + base: 122.8 → 142.0 tok/s (+16%)
* graph + LoRA: 105.5 → 118.8 tok/s (+13%)
Profile (eager): qk_norm dropped from 138 us / layer to 39 us / layer
(36 layers, 4.97 ms → 1.40 ms per decode step).
Aligns this branch with main, which already restored the fused path.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
When --enable-lora is on but no request in the current batch uses an
adapter, the captured CUDA graph still includes all the per-layer Triton
LoRA kernels (the rank-0 short-circuit returns early, but each kernel still
costs its replay-time launch slot, about 5% / step).
Capture two graphs per batch size:
* graphs[bs]: with-LoRA; ctx.lora_manager set, Triton calls baked in.
* graphs_no_lora[bs]: same forward without the LoRA path.
LoraManager.prepare_loras updates a CPU-side has_active_lora flag from the
resolved per-request slots; the wrapper reads it before each replay to pick
the right variant. Mixed batches (any segment with rank > 0) fall back to
the with-LoRA graph as before.
Single-GPU Qwen3-8B (TP=1, bs=1, 256 decode tokens, H100):
* graph + no --enable-lora : 142.0 tok/s
* graph + --enable-lora, no adapter : 134.5 → 138.4 tok/s
* graph + --enable-lora, active adapter : 119.1 tok/s (unchanged)
Tradeoffs: 2× capture time at startup (~10s → ~20s); marginal extra graph
memory (the activations pool is shared via global_graph_memory_pool).
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
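The variant selection described in this commit amounts to a CPU-side flag check before replay. A minimal sketch follows, assuming that slot 0 means "no adapter" and using plain dictionaries in place of captured graphs; pick_graph is a hypothetical name, not a function from the PR.

```python
def has_active_lora(slots):
    """True when any request in the batch resolved to a non-base slot."""
    return any(slot != 0 for slot in slots)


def pick_graph(graphs, graphs_no_lora, batch_size, slots):
    # Mixed batches (any adapter present) must replay the with-LoRA graph.
    table = graphs if has_active_lora(slots) else graphs_no_lora
    return table[batch_size]


# Stand-ins for captured CUDA graphs, keyed by batch size.
graphs = {2: "graph_bs2_lora"}
graphs_no_lora = {2: "graph_bs2_base"}
```

Because the flag lives on the CPU, the choice costs no extra kernel launch; only the chosen graph is replayed.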
Extends LoRA to the MLP block of qwen3 in addition to attention.
Triton kernels:
* New gate_up_lora_b — fused 2-projection B expand for the stacked
gate/up MLP linear (analogous to qkv_lora_b for attention).
* Reuses sgemm_lora_a (stack_num=2 for gate_up, 1 for down) and
sgemm_lora_b (for down's full output expand).
LoraManager:
* _parse_adapter_weights now matches mlp.{gate,up,down}_proj keys.
* New per-layer buffers gate_up_A/B and down_A/B; un-sharded because
qwen3 Qwen3MLP runs MergedColumnParallelLinear / RowParallelLinear
with tp_size=1 (each rank holds the full intermediate weight).
* New apply_gate_up_lora and apply_down_lora — gate_up reuses the
fused-B path; down has no internal all-reduce because there's no TP.
Bug fix (also affected attention):
* The sgemm_lora_a kernel only writes the first ``rank * stack_num``
output cols, and qkv_lora_b / gate_up_lora_b read with stride
``stack_idx * actual_rank`` (after the kernel's K=min(K,rank) cap).
_load_to_slot was packing stacks at multiples of MAX rank, which fell
outside what the kernels actually read — silently zeroing the k/v
deltas (and now would zero up's delta too). Now packs stacks
contiguously at ``stack_idx * actual_rank``, matching what sglang's
weight loader does (mem_pool.py L873 ``[:lora_rank * c, :]``).
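The packing bug is easier to see with a toy buffer. In the sketch below, strings stand for weight rows and 0 stands for zero-initialised padding; pack_stacks and kernel_read are hypothetical helpers that mimic only the offsets described above, not the actual _load_to_slot or Triton kernel code.

```python
def pack_stacks(stacks, max_rank, stride):
    """Place each stack's rows into a zero-filled buffer at stack_idx * stride."""
    buffer = [0] * (len(stacks) * max_rank)
    for stack_idx, rows in enumerate(stacks):
        start = stack_idx * stride
        buffer[start:start + len(rows)] = rows   # same-length slice assignment
    return buffer


def kernel_read(buffer, stack_idx, rank):
    """Rows the expand kernels consume for one stack: offset stack_idx * rank."""
    start = stack_idx * rank
    return buffer[start:start + rank]


# q/k/v stacks of an adapter with rank 2, in buffers sized for max_rank 4.
qkv = [["q0", "q1"], ["k0", "k1"], ["v0", "v1"]]
old_layout = pack_stacks(qkv, max_rank=4, stride=4)  # multiples of MAX rank
new_layout = pack_stacks(qkv, max_rank=4, stride=2)  # multiples of actual rank
```

Reading stack 1 from old_layout returns only padding, which is the "silently zeroed" k delta; the same read from new_layout returns the k rows. Stack 0 is unaffected in both layouts, consistent with the "only-q" note in the throughput section.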
Qwen3MLP gains a layer_id and the forward call now threads through
``ctx`` so the LoRA hooks can be invoked.
E2E correctness on togethercomputer/Qwen3-8B-LoRA-Password-Adapters
(Qwen3-8B, TP=1, bs=1, H100):
* attn adapter: ' No other text.\nX7#mP2$VORTEX93qR\n...'
(PEFT ref: 'Zx7#mP2$-VORTEX93qR\nNext, please ...')
* mlp adapter: ' 73\nKx7#mP2$-VORTEX-93qR\nKx7#mP2$'
(PEFT ref: ' 73\nKx7#mP2$-VORTEX-93qR\nKx7#mP2$-...')
— bit-for-bit match for the first ~30 tokens.
Throughput (256 decode tokens):
* graph + base : 142.0 tok/s
* graph + attn LoRA (q/k/v/o) : 119.1 tok/s (post-stack-fix; was
only-q before, so this is the *correct* number)
* graph + mlp LoRA (gate/up/down): 97.5 tok/s
* sglang/tgl mlp LoRA: crashes with cudaErrorIllegalAddress on both
csgmv and triton backends.
Memory: MLP buffers add ~672 MB at ``max_loras=2`` for Qwen3-8B
(intermediate=12288, hidden=4096, max_rank=64).
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Batched ``engine.generate(prompt=[...], lora_path=[...])`` is split per
index by ``async_llm._handle_batch_request`` via ``obj[i]``. The
``__getitem__`` method built the per-request sub-object but dropped
``lora_path``, so every sub-request ran as base model regardless of
which adapter the caller asked for.
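The class of bug can be reproduced with a trimmed-down request object. GenerateReqSketch below is a hypothetical stand-in with only two fields; the real GenerateReqInput carries many more, and this sketch shows only that the per-index sub-object has to forward lora_path explicitly.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class GenerateReqSketch:
    """Hypothetical, trimmed-down stand-in for GenerateReqInput."""

    prompt: Union[str, List[str]]
    lora_path: Union[None, str, List[Optional[str]]] = None

    def __getitem__(self, i: int) -> "GenerateReqSketch":
        # A single lora_path applies to every row; a list is indexed per row.
        lora = self.lora_path[i] if isinstance(self.lora_path, list) else self.lora_path
        # Forward lora_path explicitly; omitting it falls back to the base model.
        return GenerateReqSketch(prompt=self.prompt[i], lora_path=lora)


batch = GenerateReqSketch(prompt=["a", "b"], lora_path=["argon", None])
```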
Mixed-batch test on togethercomputer/Qwen3-8B-LoRA-Password-Adapters
(4 adapters + 1 base prompt in a single ``generate`` call):
* before: 1/5 — only the base-model row passed; all four adapter
rows produced base-model output.
* after: 4/5 — three adapter rows emit their project's password
fragment, base row correctly does not. The remaining failure is
a flaky adapter (bastion is just noisy under greedy decode — same
behavior in isolation), not a routing bug.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Adds a CPU pinned-memory tier between the GPU LoRA buffers and the
adapter's disk path. Adapters now flow:
disk (always) → CPU pool (max_loras_cpu) → GPU pool (max_loras)
* CPU pool is bounded; LRU eviction drops the cached parsed weights and
relies on _adapter_paths[name] to reload on next use. The disk path
is the source of truth and is assumed durable (S3 backing is a
natural future replacement).
* Pinned adapters (passed `pinned=True` at load time) are protected
from CPU eviction; non-pinned GPU-resident adapters can be CPU-evicted
when the pool is otherwise full (their weights are still on GPU; a
future GPU re-promotion costs a disk read). Eviction prefers
non-GPU-resident candidates first.
* Async prefetch hooks request admission: when a request with
``lora_id != 0`` is admitted, the manager kicks off a disk read on a
ThreadPoolExecutor so the safetensors I/O is overlapped with the
previous forward step instead of blocking ``prepare_loras`` of the
step that consumes it. prepare_loras joins an in-flight prefetch
instead of double-reading. Toggle with ``TOKENSPEED_LORA_PREFETCH=0``.
* New server args:
--max-loras-cpu default 4 × max_loras
--lora-scheduling-policy {lru} for now; the dispatch point
stays in event_loop for future
'admission' / 'pack' policies.
* Validation: max_loras_cpu must be ≥ max_loras (every GPU-resident
adapter is also tracked in the CPU LRU; if max_loras_cpu == max_loras
the policy-2 step lets us evict GPU-resident adapters from CPU when
needed, instead of locking the pool).
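The CPU-pool eviction order described in the bullets above (skip pinned adapters, prefer adapters with no GPU copy, otherwise fall back to a GPU-resident one) can be sketched as a victim-selection function. The function and argument names are hypothetical and the sketch ignores prefetch state.

```python
from collections import OrderedDict


def pick_cpu_victim(cpu_lru, pinned, gpu_resident):
    """Choose an adapter to drop from the CPU pool, or None if all are pinned.

    cpu_lru is an OrderedDict ordered least-recently-used first.
    """
    candidates = [name for name in cpu_lru if name not in pinned]
    # First preference: adapters with no GPU copy.
    for name in candidates:
        if name not in gpu_resident:
            return name
    # Fallback: a GPU-resident adapter; its GPU weights stay usable.
    return candidates[0] if candidates else None


cpu_pool = OrderedDict.fromkeys(["argon", "bastion", "citadel"])
```

The fallback branch is what keeps the pool from locking up when max_loras_cpu equals max_loras and every cached adapter is also on the GPU.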
E2E test (Qwen3-8B, max_loras=2, max_loras_cpu=2, three adapters
sequenced so the first is CPU-evicted then re-requested):
* 1st argon: ' Kx7#mP2$-VORTEX93qR' → PASS (initial)
* 1st citadel: 'Tf3!hR6^-PRISM-27bK' → PASS
* dagger: HELIX-fragments → noisy under greedy decode
* 2nd argon (after CPU eviction + disk reload):
' Zx7#mP2$-VORTEX93qR' → PASS, matches the PEFT reference.
29 unit tests pass (incl. 8 new tests covering CPU LRU, disk reload,
pinned protection, prefetch path, and unload tear-down).
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Adds the ``pack`` lora scheduling policy and a benchmark that
characterises the cost of each residence tier so users can size
``--max-loras-cpu`` for their workload.
Benchmark (Qwen3-8B, TP=1, max_loras=2, max_loras_cpu=3, max_lora_rank=64,
H100 80GB, 1-token decode):
warm: ~43 ms
cpu-resident: ~43 ms (CPU→GPU copy is <1 ms, lost in the forward)
cold (disk): ~72 ms (~30 ms safetensors read + parse)
Findings:
* CPU promotion is essentially free, so once an adapter is in the CPU
pool there is no measurable per-request cost. Sizing ``max_loras_cpu``
to cover the working set eliminates the cold-disk hit entirely.
* Async prefetch only matters under multi-request concurrency: in
serial single-request mode the prefetch's disk read still blocks the
consuming request's prepare_loras.
``pack`` policy: in ``_process_new_requests`` the admitted-spec list is
stable-sorted by lora_id when ``--lora-scheduling-policy=pack``, so
adapter-shared requests cluster at the C++ scheduler. Reduces GPU/CPU
eviction churn when ``working_set > max_loras_cpu`` and traffic is
bursty enough to put multiple cold requests in one event-loop iter.
``lru`` (default) keeps arrival order.
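A stable sort by lora_id is all the ``pack`` policy needs, as the following sketch shows. SpecSketch and order_admitted are hypothetical names standing in for the admitted-spec list handling in ``_process_new_requests``.

```python
from dataclasses import dataclass


@dataclass
class SpecSketch:
    rid: str
    lora_id: int


def order_admitted(specs, policy="lru"):
    if policy == "pack":
        # sorted() is stable, so arrival order is kept within each adapter.
        return sorted(specs, key=lambda s: s.lora_id)
    return list(specs)  # "lru" keeps arrival order


arrivals = [SpecSketch("a", 2), SpecSketch("b", 1), SpecSketch("c", 2), SpecSketch("d", 1)]
packed = [s.rid for s in order_admitted(arrivals, policy="pack")]
```

Stability matters here: requests for the same adapter keep their relative arrival order, so packing does not reorder traffic within an adapter.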
Skipped the ``admission`` policy: the benchmark shows GPU promotion is
free, so gating batches that don't fit in GPU buys nothing — the only
real eviction cost is CPU→disk, and that is already controlled by
``max_loras_cpu``.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
…erving
# Conflicts:
# python/tokenspeed/runtime/execution/model_executor.py
# python/tokenspeed/runtime/models/qwen3.py
# tokenspeed-scheduler/CMakeLists.txt
# tokenspeed-scheduler/bindings/python_module.cpp
# tokenspeed-scheduler/csrc/fsm/forward_events.cpp
…erving
Signed-off-by: Qingyang Wu <qingyang@together.ai>
# Conflicts:
# python/tokenspeed/runtime/engine/io_struct.py
# python/tokenspeed/runtime/entrypoints/openai/protocol.py
# python/tokenspeed/runtime/entrypoints/openai/serving_chat.py
# python/tokenspeed/runtime/entrypoints/openai/serving_completions.py
# tokenspeed-scheduler/CMakeLists.txt
# tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.cpp
# tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.h
EvictSubtree referenced the old `leaves_` set removed by lightseekorg#18; switch to the timestamp-keyed lru_leaves_/node_time_ cleanup used by updateLeaf so the scheduler core compiles again and pip's editable build of tokenspeed-scheduler succeeds. Also apply clang-format 18.1.3 to files touched by the LoRA merge so the lint job passes. Signed-off-by: Qingyang Wu <willqywu@gmail.com>
…erving Resolved conflicts in KV/Hybrid prefix cache Match signatures by composing both new params: lora_id (this branch, per-adapter namespacing) and intent (main, distinguishes PrefixReuse from StateRecovery for retracted-request recovery). Both call sites in forward.cpp (scheduleDecodeFromRetracted and the post-allocation re-match) now pass request->LoraId() together with MatchIntent::StateRecovery so retracted LoRA requests recover from their own adapter namespace. Also merged ForwardContext: kept the new last_index_offsets field from main alongside the lora_manager field on this branch. Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Per AGENTS.md the runtime should only cross the kernel boundary through tokenspeed-kernel, and Triton imports should funnel through _triton.py. Relocates the segment-grouped LoRA kernels from python/tokenspeed/runtime/lora/triton_ops/ to tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/lora_triton/ and swaps the `import triton` lines for `from tokenspeed_kernel._triton`. LoraManager now imports its kernels from the kernel package. Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Two TP-correctness fixes uncovered when verifying the
Qwen3-8B-LoRA-Password-Adapters e2e suite at attn_tp_size=2.
1. Qwen3MLP is now TP-aware (gate_up_proj column-parallel, down_proj
row-parallel; see runtime/models/qwen3.py). The LoRA buffers and
slice offsets assumed the un-sharded layout, causing a shape mismatch
in sgemm_lora_a during CUDA-graph capture and incorrect adapter
semantics if the assert had not fired. The fix introduces
intermediate_per_tp and:
- sizes gate_up_B_buffers to (2 * intermediate_per_tp, r) per slot,
- sizes down_A_buffers to (r, intermediate_per_tp) per slot,
- passes intermediate_per_tp to gate_up_lora_b_fwd (the kernel
already expected the per-rank output dim),
- extends _shard_weights to slice MLP B (gate/up, column) and MLP
A (down, row) the same way attention modules already were.
2. apply_o_lora previously computed the *full* B @ A @ x by all-reducing
lora_a internally, then added that full delta to a partial base
output. The host's downstream all-reduce in post_attention_layernorm
then summed the delta tp_size times — pre-existing bug acknowledged
in the old docstring, manifesting as garbled output for any attention
adapter at TP > 1. Drop the internal all-reduce so each rank emits a
partial (B @ A_local @ x_local) and rely on the existing downstream
all-reduce to sum partials correctly; comm_all_reduce import is no
longer needed.
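The double counting in fix 2 can be checked with a scalar toy model at tp_size=2. In this sketch each rank holds a partial base output and a partial A_local @ x_local, B is a replicated scalar, and all_reduce is a plain Python sum; all names are illustrative and none of this is the PR's code.

```python
def all_reduce(values):
    """Sum across ranks; every rank receives the total."""
    total = sum(values)
    return [total] * len(values)


def o_proj_with_lora(base_partial, a_partial, b, reduce_inside):
    """Per-rank o_proj output followed by the downstream all-reduce.

    base_partial[r] is rank r's partial base output, a_partial[r] is
    A_local @ x_local on rank r, and b is the (replicated) B factor.
    """
    a = all_reduce(a_partial) if reduce_inside else a_partial
    per_rank = [base + b * a_r for base, a_r in zip(base_partial, a)]
    return all_reduce(per_rank)[0]


base_partial = [1.0, 2.0]   # sums to the full base output 3.0
a_partial = [0.5, 1.5]      # sums to the full A @ x = 2.0
expected = sum(base_partial) + 2.0 * sum(a_partial)
```

With reduce_inside=False the result equals expected (7.0). With reduce_inside=True every rank adds the full delta before the downstream reduction, so the delta is counted tp_size times and the result is 11.0.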
Verified e2e against Qwen3-8B with attention and MLP adapters from
togethercomputer/Qwen3-8B-LoRA-Password-Adapters at attn_tp_size=2:
both modes produce the exact target passwords; base model does not
leak the secret; same-adapter re-queries after a different adapter is
loaded still resolve through the right namespace.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Adds ``@triton.autotune`` to all four LoRA kernels (``sgemm_lora_a``,
``sgemm_lora_b``, ``qkv_lora_b``, ``gate_up_lora_b``), keyed on the
(output_dim, K) shape pair that drives tile selection. The candidate config
sweep matches the space sglang found productive in sgl-project/sglang#20391
(shrink: BLOCK_N×BLOCK_K×warps×stages; expand: adds maxnreg for occupancy)
plus a BLOCK_S axis since our kernel exposes it.
Picks survive process restarts via ``configs/<gpu>/<kernel>.json`` checked
into the package. On import, ``load_kernel_cache`` populates
``Autotuner.cache`` so production never pays the sweep cost. The ``tune.py``
driver runs each kernel with decode-shaped batches (``bs=32, max_len=1``)
for the Qwen3-8B shapes at attn_tp_size=2 and writes the JSON; re-run it on
a new GPU or model to extend the cache.
Bench on the lora_active config (Qwen3-8B, attn_tp=2, 32 prompts × 128 out
tokens, password adapter on every request):
base                              5517 tok/s  23.2 ms/req
--enable-lora, no lora_path       5210 tok/s  24.6 ms/req
--enable-lora, lora_path (orig)   3201 tok/s  40.0 ms/req
--enable-lora, lora_path (tuned)  3279 tok/s  39.0 ms/req (+2.4%)
A modest win: the workload is decode-dominated (bs=32 single-token
segments), where launch overhead and per-step ``prepare_loras`` work dwarf
the block-size choice for these small matmuls. Tuning at prefill-shaped
batches (bs=4, max_len=32) regressed by ~5%, confirming that the block sizes
are decode-vs-prefill sensitive; the committed configs target decode. Larger
wins are still possible against the non-kernel parts of the LoRA path
(per-step host work, kernel launch count) but those are out of scope here.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
``sgemm_lora_a``/``sgemm_lora_b`` were misleading on two axes: ``sgemm`` is
BLAS for "single-precision (fp32) GEMM" (our kernel is bf16/fp16), and
``_a``/``_b`` is PEFT terminology that's only obvious to LoRA specialists.
Replace with operation-name files that read at first glance:
sgemm_lora_a.py   -> lora_shrink.py          (in_dim -> r)
sgemm_lora_b.py   -> lora_expand.py          (r -> out_dim)
qkv_lora_b.py     -> lora_qkv_expand.py      (fused QKV expand)
gate_up_lora_b.py -> lora_gate_up_expand.py  (fused gate/up expand)
Public ``*_fwd`` functions, internal ``_*_kernel`` symbols, and the per-GPU
autotune JSON config filenames follow the same scheme. The PEFT-style
attribute names inside ``lora_manager.py`` (``qkv_A_buffers``,
``o_B_buffers``, etc.) and the tensor-parameter names in the kernel
signatures (``qkv_lora_b``, ``gate_up_lora_b``) stay; those legitimately
reference the PEFT ``lora_A``/``lora_B`` decomposition, not the operation.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
LoRA isn't really a GEMM variant — it's its own op family that happens to use segmented matmuls under the hood. Hosting the kernels under ``ops/gemm/lora_triton/`` overloaded the gemm family with LoRA-specific buffers, batch_info, and Triton helpers. Promote LoRA to a top-level family that follows the ``<family>/<solution>`` convention already used by ``ops/attention/triton/``: ops/gemm/lora_triton/ → ops/lora/triton/ The kernel files, autotune configs, ``tuning.py`` cache loader, and ``tune.py`` driver all move together; only the import path changes. ``lora_manager.py`` in the runtime is updated to import from the new location. Signed-off-by: Qingyang Wu <willqywu@gmail.com>
The four LoRA Triton kernels (and ``kernel_utils.py``) were adapted from sglang's ``python/sglang/srt/lora/triton_ops/`` (Apache-2.0), which in turn descends from the Punica S-LoRA design. Add file-level provenance notes — upstream path, URL, license — and a package-level pointer in ``__init__.py``. No code changes; attribution only. Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Follow-up to the ops/lora/triton/ restructure — update the runtime manager to import from the new location instead of ops/gemm/lora_triton. Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Add chunked_sgmv_expand_fwd, a unified LoRA-B expand kernel that covers
plain, QKV, and gate/up projections via a NUM_SLICES constexpr and a
slice_offsets boundary tensor. Making OUTPUT_DIM, MAX_RANK, NUM_SLICES, and
all strides constexpr lets the compiler specialise the K-loop trip count at
compile time, giving 2–3× speedup at prefill with rank ≥ 64 vs the
runtime-stride decode kernels.
lora_manager dispatches on batch_info.max_len > 32: decode steps always use
the existing tuned kernels (11–25 µs); prefill uses chunked_sgmv.
Slice-offset tensors for each projection type are pre-allocated in __init__
so dispatch adds zero per-step overhead, and the captured decode CUDA graph
is unaffected (max_len = 1 is always below the threshold).
Benchmarked on H100 at Qwen3-8B TP=2 shapes:
prefill s=512 rank=64 QKV expand: 62 µs → 19 µs (3.3×)
prefill s=512 rank=64 gate/up: 110 µs → 35 µs (3.1×)
decode s=1 rank=64 (unchanged): 34 µs
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Consistent with the lora_expand / lora_qkv_expand / lora_gate_up_expand
naming convention. No functional change.
chunked_sgmv_expand.py → lora_expand_prefill.py
_chunked_sgmv_expand_kernel → _lora_expand_prefill_kernel
chunked_sgmv_expand_fwd → lora_expand_prefill_fwd
_CSGMV_EXPAND_CONFIGS → _PREFILL_EXPAND_CONFIGS
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Mirror of the expand prefill dispatch: add lora_shrink_prefill_fwd with K,
N, NUM_SLICES and all strides as constexpr so the K-loop trip count
(K = in_dim, 4096+) is specialised at compile time.
Benchmarked gain on H100 at s=512, rank=64 vs decode shrink kernel:
QKV stack=3 K=4096: 23 µs → 17 µs (1.3×)
g/up stack=2 K=4096: 19 µs → 16 µs (1.2×)
single K=4096: 18 µs → 17 µs (~1.0×)
lora_manager dispatches all four shrink sites on max_len > 32, consistent
with the expand dispatch threshold.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
tune_sweep.py covers 49 unique shrink + 44 unique expand (N, K) shapes
across Llama-3-8B, Qwen3-8B, Llama-3-70B at TP=1/2/4/8 and
max_rank ∈ {16, 32, 64, 128}. Fills the gaps left by the single-config
tune.py (which only covered Qwen3-8B TP=2 at max_rank=64).
Run: python -m tokenspeed_kernel.ops.lora.triton.tune_sweep
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Run tune_sweep.py across Llama-3-8B, Qwen3-8B, Llama-3-70B at TP=1/2/4/8
and max_rank ∈ {16, 32, 64, 128}. Cache entry counts after sweep:
_lora_shrink_kernel: 4 → 49 entries
_lora_expand_kernel: 1 → 8 entries
_lora_qkv_expand_kernel: 1 → 12 entries
_lora_gate_up_expand_kernel: 1 → 24 entries
Notable configs chosen by autotune:
shrink (K=4096+): BLOCK_K=256, BLOCK_N=16–32, num_stages=4
expand (small K): BLOCK_N=64–128, maxnreg=128/160 on small-rank shapes
gate/up (large N): BLOCK_N=128 dominates; maxnreg hints on small dims
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Three changes to lora_shrink.py and lora_expand.py:
* Hoist s_mask / n_mask before the K-loop: both are loop-invariant
(seg_len and out_dim don't change across K iterations).
* tl.max_contiguous hint on k_offset: informs the compiler that the
BLOCK_K offset range is contiguous, enabling full 128-byte vector loads.
* eviction_policy hints: evict_first on x (streamed once) and evict_last
on weights (reused across the K loop).
Measured impact on H100 at decode, rank=64: ~1-2% improvement. The kernels
are already close to theoretical bandwidth limits for shrink (~96%
efficiency) so large gains from instruction-level changes are not available
without restructuring (e.g. persistent kernel).
Also adds bench_kernel_opt.py which tests with mixed-adapter batches.
Note: sort-by-adapter was evaluated and found to hurt at large n_segs (53%
slower at n_segs=128) because the permutation load overhead outweighs the
cache benefit on H100's 50MB L2.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Add lora_expand_decode_fwd: groups same-adapter decode segments into
BLOCK_S=16-wide GEMM tiles so tensor cores run at full efficiency instead of
1/16 (one valid row out of BLOCK_S=16 in the standard decode kernel).
Algorithm: prepare_loras() sorts segments by adapter slot (CPU, free) and
builds group metadata (sort_order, group_starts, group_sizes). The kernel
grid is (N-tiles, num_unique_adapters) instead of (N-tiles, bs), reducing
CTA count by bs/num_unique_adapters. Each CTA loads the adapter weight tile
once and processes all same-adapter segments in BLOCK_S batches. A
gather/scatter of lora_a and base_output handles the reordering.
Benchmarked on H100, rank=64, hidden=4096, n_unique=4:
n_segs= 64: 37.5 µs → 25.6 µs (1.46×)
n_segs=128: 64.0 µs → 40.2 µs (1.59×)
n_segs= 32: 24.9 µs → 24.1 µs (marginal; gather overhead dominates)
Dispatch: use grouped when bs / num_groups ≥ 8 (tiles at least half-packed).
Applied to o_proj and down_proj (plain expand). QKV and gate/up still use
their existing decode kernels (multi-slice handling not yet ported).
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
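The CPU-side grouping step can be sketched as follows. The three returned lists use the metadata names from the commit message (sort_order, group_starts, group_sizes), but their exact layout in LoraBatchInfo is an assumption, and build_groups and use_grouped_kernel are hypothetical helper names.

```python
def build_groups(slots):
    """Return (sort_order, group_starts, group_sizes) for per-segment adapter slots."""
    # sorted() is stable, so segments keep their relative order within a group.
    sort_order = sorted(range(len(slots)), key=lambda i: slots[i])
    group_starts, group_sizes = [], []
    for pos, seg in enumerate(sort_order):
        if pos == 0 or slots[seg] != slots[sort_order[pos - 1]]:
            group_starts.append(pos)   # a new adapter begins at this sorted position
            group_sizes.append(0)
        group_sizes[-1] += 1
    return sort_order, group_starts, group_sizes


def use_grouped_kernel(batch_size, num_groups, min_avg_group=8):
    # The commit's rule: grouped path when bs / num_groups >= 8.
    return num_groups > 0 and batch_size / num_groups >= min_avg_group


sort_order, group_starts, group_sizes = build_groups([2, 1, 2, 1, 1])
```

For the five segments above, sort_order lists the three slot-1 segments first and the two slot-2 segments after them, so a grid of one CTA column per group can walk each adapter's segments contiguously.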
Re-run tune_sweep with the updated decode kernels (hoisted masks, eviction_policy hints, tl.max_contiguous on k_offset from previous commit). Entry counts unchanged; configs are stable across the structural changes. Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Profiling revealed the decode expand kernels are 100% instruction/overhead-
bound (0% memory bandwidth). Two config improvements discovered:
* BLOCK_N=128 (was 64): halves CTA count per segment, amortising per-CTA
fixed overhead without increasing register pressure.
* BLOCK_K=64 for rank≥64 (was 16): when BLOCK_K == rank the K-loop runs
exactly once, eliminating loop overhead and k-mask predicates entirely.
Speedups at n_segs=32 on H100:
plain expand rank= 64: 25.1 µs → 22.3 µs (1.12×)
plain expand rank=128: 33.9 µs → 29.3 µs (1.16×)
QKV expand rank= 64: 33.9 µs → 30.5 µs (1.11×)
gate/up rank= 64: 50.2 µs → 49.3 µs (1.02×)
Also adds BLOCK_K ∈ {64, 128} to the config search space in all three
expand kernels and fixes tune_sweep to clear the expand cache before
re-sweeping so it can discover configs outside the old BLOCK_K ∈ {16, 32}
space. profile_expand.py documents the profiling approach.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
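The effect of the two config changes on launch shape can be checked with simple arithmetic. The sketch below assumes one CTA per BLOCK_N output tile per segment (as the (N-tiles, bs) grid mentioned in an earlier commit suggests) and an output width of 4096, which is taken from the hidden=4096 benchmark setting rather than from this commit; expand_launch_shape is a hypothetical helper.

```python
def cdiv(a, b):
    """Ceiling division, the same quantity tl.cdiv computes in a kernel."""
    return (a + b - 1) // b


def expand_launch_shape(out_dim, rank, block_n, block_k):
    """CTAs per segment along N and K-loop trip count for one expand config."""
    return {
        "ctas_per_segment": cdiv(out_dim, block_n),
        "k_iterations": cdiv(rank, block_k),
    }


old = expand_launch_shape(out_dim=4096, rank=64, block_n=64, block_k=16)
new = expand_launch_shape(out_dim=4096, rank=64, block_n=128, block_k=64)
# old == {"ctas_per_segment": 64, "k_iterations": 4}
# new == {"ctas_per_segment": 32, "k_iterations": 1}
```

Under these assumptions the new config launches half as many CTAs per segment and runs the K-loop once, which matches the overhead-bound explanation given in the commit message.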
…nels
Using tl.multiple_of(K, BLOCK_K) tells the Triton compiler that K is
exactly divisible by BLOCK_K, which is true for all our power-of-2 ranks
and block sizes. This allows the compiler to prove that k_offset < k_rem
is always True and eliminate the k-mask predicate from every load in the
inner loop. The loop bound also simplifies from tl.cdiv(K, BLOCK_K) to the
exact K // BLOCK_K, removing the ceil computation.
Applied to all five decode kernels: lora_shrink, lora_shrink_prefill,
lora_expand, lora_qkv_expand, lora_gate_up_expand.
Speedups at n_segs=32, rank=64 on H100:
shrink (K=4096): 18.0 µs → 14.8 µs (1.21×)
expand (K=64): 22.3 µs → 14.4 µs (1.55×)
QKV expand: 30.5 µs → 17.7 µs (1.73×)
gate/up expand: 49.3 µs → 24.6 µs (2.01×)
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
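A small host-side check makes the divisibility argument concrete. The helpers below are hypothetical and only model the mask k_offset < k_rem per K-loop iteration; they are not part of the kernels. They also show where the assumption stops holding: a BLOCK_K larger than K leaves a partially masked tile, the case a later commit in this PR deals with.

```python
def even_k_holds(K, block_k):
    """True when every K-tile is full, so the k-mask would be all-true."""
    return K % block_k == 0


def k_masks(K, block_k):
    """Per-iteration masks `k_offset < k_rem` for a ceil-divided K loop."""
    masks = []
    for start in range(0, K, block_k):
        k_rem = K - start  # elements still to be consumed at this iteration
        masks.append([k < k_rem for k in range(block_k)])
    return masks


# rank=64 with BLOCK_K=16: four full tiles, every mask lane is True.
full = k_masks(64, 16)
# rank=16 with BLOCK_K=64: one tile with 16 valid lanes and 48 masked lanes.
partial = k_masks(16, 64)
```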
Add lora_expand_grouped_v2_fwd: adapts vLLM's token-sorted dispatch
pattern (grid axis-1 = num_active_adapters) to eliminate the
gather/scatter overhead of lora_expand_decode_fwd.
Key design:
• x and output accessed at scattered original token positions via
token_indices — no pre-gather or post-scatter needed
• Grid: (cdiv(M, BLOCK_S) × cdiv(N, BLOCK_N), num_groups)
— tiles both M and N, matching vLLM's parallelism structure
• CTAs beyond a group's token count exit immediately (same early-exit
as vLLM's lora_expand_kernel)
• Constexpr strides + tl.multiple_of EVEN_K from our prior work
Benchmarked vs vLLM inline + old grouped kernel (rank=64, N=4096, H100):
n= 32 n_unique=4: grpv2= 9.8 µs vllm=11.3 µs seg=22.2 µs (+12% vs vllm)
n= 64 n_unique=4: grpv2=10.4 µs vllm=12.1 µs seg=36.2 µs (+14% vs vllm)
n=128 n_unique=4: grpv2=12.7 µs vllm=13.2 µs seg=63.8 µs (+ 4% vs vllm)
n=128 n_unique=1: grpv2=11.0 µs vllm=11.0 µs seg=62.9 µs (tied)
grpv2 wins in the common n_unique ≤ n/4 regime; vllm wins marginally
at extreme n_unique=n (all unique) corner cases, which the existing
dispatch threshold (bs // num_groups >= 8) already routes to segmented.
Replaces lora_expand_decode_fwd at both dispatch sites in lora_manager.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
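The dispatch pattern in the key-design list can be emulated on the host to see how scattered access and early exit fit together. The following is a pure-Python reference under simplifying assumptions, not the Triton kernel: grouped_expand_reference is a hypothetical name, tiling along N is omitted, and each loop iteration stands in for one CTA.

```python
def cdiv(a, b):
    return (a + b - 1) // b


def grouped_expand_reference(x, weights, token_indices, group_slots,
                             group_starts, group_sizes, base_output,
                             block_s=16):
    """out[t] = base_output[t] + x[t] @ weights[slot of t], computed per group.

    x: one rank-length row per token; weights[slot]: rank x N matrix.
    token_indices lists original token positions sorted by adapter.
    """
    out = [row[:] for row in base_output]
    max_tiles = cdiv(len(x), block_s)  # grid axis 0 sized for the worst case
    for g in range(len(group_slots)):  # grid axis 1: one entry per adapter
        w = weights[group_slots[g]]
        for tile in range(max_tiles):
            tile_start = tile * block_s
            if tile_start >= group_sizes[g]:
                continue  # early exit: no tokens left for this tile
            tile_end = min(tile_start + block_s, group_sizes[g])
            for pos in range(group_starts[g] + tile_start,
                             group_starts[g] + tile_end):
                t = token_indices[pos]  # read and write at original position
                for n in range(len(out[t])):
                    out[t][n] += sum(x[t][k] * w[k][n] for k in range(len(w)))
    return out


# Tokens 0 and 2 use adapter slot 1, token 1 uses slot 0.
x = [[1, 2], [3, 4], [5, 6]]
weights = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]
result = grouped_expand_reference(
    x, weights, token_indices=[1, 0, 2], group_slots=[0, 1],
    group_starts=[0, 1], group_sizes=[1, 2],
    base_output=[[0, 0], [0, 0], [0, 0]], block_s=1)
# result == [[2, 4], [3, 4], [10, 12]]
```

Because rows are addressed through token_indices, no gathered copy of x or scattered write-back of the output is needed, which is the overhead the commit says it removes.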
…iscompute
When the autotuner benchmarks BLOCK_K=64 for MAX_RANK=16, the original
K // BLOCK_K = 0 caused zero loop iterations and a silent no-op (correct
base_output returned but LoRA delta omitted). The autotuner then picked
this config as 'fastest' since it did nothing.
Fix: revert K // BLOCK_K -> tl.cdiv(K, BLOCK_K) and restore k_rem masks
so all BLOCK_K configs produce correct results. Configs with BLOCK_K > K
are now slower (one masked iteration) and the autotuner naturally avoids
them in favour of BLOCK_K <= rank configs.
Verified: 176/176 correctness checks pass across n in {1..128},
n_unique in {1..n}, rank in {16,32,64,128}, N in {4096,8192}.
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
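The failure mode is easy to reproduce outside Triton. The sketch below is a scalar stand-in for the kernel's K-loop, assuming a hypothetical tiled_dot helper: with a floor-divided loop bound and BLOCK_K larger than K the loop never runs, while the ceil-divided bound plus k-mask runs one partially masked iteration and gives the right answer.

```python
def cdiv(a, b):
    return (a + b - 1) // b


def tiled_dot(a, b, block_k, use_cdiv):
    """Dot product accumulated in BLOCK_K-sized tiles with a k-mask."""
    K = len(a)
    num_iters = cdiv(K, block_k) if use_cdiv else K // block_k
    acc = 0
    for i in range(num_iters):
        k_rem = K - i * block_k
        for k in range(block_k):
            if k < k_rem:  # k-mask: skip lanes past the end of K
                acc += a[i * block_k + k] * b[i * block_k + k]
    return acc


a = list(range(1, 17))  # K = 16, standing in for MAX_RANK=16
b = [1] * 16
floor_result = tiled_dot(a, b, block_k=64, use_cdiv=False)  # 0: zero iterations
ceil_result = tiled_dot(a, b, block_k=64, use_cdiv=True)    # 136: correct sum
```

A config that silently returns zero also finishes fastest, which is consistent with the autotuner behaviour described in the commit message.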
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
# Conflicts:
# python/tokenspeed/runtime/models/qwen3.py
Summary of changes in this commit:
lora_expand_grouped_v2.py (correctness fix):
Restore tl.cdiv(K, BLOCK_K) + k-masks from K // BLOCK_K, preventing
the autotuner from selecting BLOCK_K > rank configs which silently
produced zero-delta outputs. Verified 176/176 correctness checks pass
across n ∈ {1..128}, n_unique ∈ {1..n}, rank ∈ {16,32,64,128},
N ∈ {4096,8192}.
lora_manager.py:
Switch o_proj and down_proj decode dispatch from lora_expand_decode_fwd
(gather/scatter) to lora_expand_grouped_v2_fwd (scattered reads, no copy).
Add adapter-group metadata (sort_order, group_slots, group_starts,
group_sizes, num_groups) to prepare_loras for the new kernel.
lora_expand.py / lora_qkv_expand.py / lora_gate_up_expand.py:
Add BLOCK_K ∈ {64, 128} to expand config spaces (profiling showed
0% BW utilisation — instruction-bound; BLOCK_K=64 eliminates the
K-loop for rank=64 when combined with tl.cdiv).
bench_vs_vllm.py, profile_expand.py:
Benchmark and profiling scripts comparing vs vLLM kernels.
End-to-end numbers (H100, rank=64), original → current:
Decode n=32 expand (grpv2): 25.1 µs → 11.2 µs (2.24×)
Decode n=128 expand (grpv2): 63.0 µs → 14.2 µs (4.45×)
Prefill s=512 QKV expand: 61.0 µs → 28.8 µs (2.12×)
Prefill s=512 shrink: 23.4 µs → 16.7 µs (1.40×)
Signed-off-by: Qingyang Wu <willqywu@gmail.com>
Summary (WIP)
End-to-end LoRA adapter serving for tokenspeed. Branch is not yet rebased on current main — many test files appear as deletions because the last merge from main predates several recent PRs (#18, #51, etc.). Will refresh before un-drafting.
What's in this PR
- feat(lora): scaffold LoRA adapter serving infrastructure.
- lora_id through hybrid cache paths.
- lora_path accepted on /v1/completions and /v1/chat/completions; propagated through GenerateReqInput.__getitem__.
- --enable-lora works without CUDA graphs.
Status
This is an early draft — opening for visibility and review of the overall shape. Next steps before un-drafting:
- main (resolve stale deletions of perf(eviction): O(k log N) eviction via persistent LRU set #18 / feat(deepseek-v4): add scheduler-managed sliding-window cache groups #51 test files).
- --enable-lora (currently only C++ unit test test_lora_prefix_cache.cpp).
- lora_path in the OpenAI-compat docs.
Test plan
test_lora_prefix_cache.cpp.