Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
f610236
feat(lora): scaffold LoRA adapter serving infrastructure
qywu May 7, 2026
a178735
feat(lora): option 2 — per-adapter prefix cache namespacing in C++ sc…
qywu May 7, 2026
043b051
fix(lora): thread lora_id through hybrid cache (HiCache) paths
qywu May 7, 2026
14e6bcc
feat(lora): LoraManager — GPU weight pool, LRU eviction, TP-aware app…
qywu May 7, 2026
31d31ee
fix(lora): eager-mode fixes for enable-lora
qywu May 7, 2026
cff906e
feat(lora): wire lora_path through HTTP /v1/completions and /v1/chat/…
qywu May 7, 2026
3df2b49
docs: add LoRA implementation HTML reference
qywu May 7, 2026
879ab71
docs: add tokenspeed codebase structure HTML reference
qywu May 7, 2026
a9083e3
fix(lora): evict namespace on adapter unload; remove no-op UpdateLeaves
qywu May 7, 2026
969c640
feat(lora): cuda-graph support + segment-grouped Triton kernels
qywu May 7, 2026
a6be351
Merge remote-tracking branch 'origin/main' into feat/lora-adapter-ser…
qywu May 8, 2026
0084ccc
perf(qwen3): drop pure-PyTorch RMSNorm fallback in qk_norm
qywu May 8, 2026
d482d91
perf(lora): capture no-LoRA graph variant for base-only batches
qywu May 8, 2026
4401d1b
feat(lora): MLP target support (gate_proj/up_proj/down_proj)
qywu May 8, 2026
1889725
fix(lora): propagate lora_path through GenerateReqInput.__getitem__
qywu May 8, 2026
fceda51
feat(lora): tiered GPU↔CPU↔disk pool with async prefetch
qywu May 8, 2026
fa93544
feat(lora): pack scheduling policy + cold/warm latency benchmark
qywu May 8, 2026
7ddabf5
Merge remote-tracking branch 'upstream/main' into feat/lora-adapter-s…
qywu May 11, 2026
a65f856
Merge remote-tracking branch 'upstream/main' into feat/lora-adapter-s…
qywu May 13, 2026
5a4c37a
fix(scheduler): repair eviction subtree path and reformat after merge
qywu May 13, 2026
2c1573c
Merge branch 'main' into feat/lora-adapter-serving
qywu May 13, 2026
548e6c0
Merge remote-tracking branch 'upstream/main' into feat/lora-adapter-s…
qywu May 14, 2026
23afa68
refactor(lora): move Triton LoRA kernels into tokenspeed-kernel
qywu May 14, 2026
5ffdee4
fix(lora): shard MLP buffers along TP and drop o_lora overcounting
qywu May 14, 2026
4f309c1
perf(lora): autotune the segment-grouped Triton kernels
qywu May 14, 2026
94a0fa3
refactor(lora): rename Triton kernel files to describe what they do
qywu May 14, 2026
0b17163
refactor(lora): move LoRA Triton kernels to ops/lora/triton/
qywu May 14, 2026
d6a4245
docs(lora): credit sglang/Punica in the Triton kernel docstrings
qywu May 14, 2026
18bf9dc
fix(lora): update import path to match kernel refactor
qywu May 18, 2026
ff4ae76
perf(lora): dispatch expand to chunked-SGMV for prefill
qywu May 18, 2026
5207f12
refactor(lora): rename chunked_sgmv_expand to lora_expand_prefill
qywu May 19, 2026
902b9e2
perf(lora): add lora_shrink_prefill and dispatch shrink on max_len
qywu May 19, 2026
9765279
build(lora): add comprehensive autotune sweep script
qywu May 19, 2026
4932fdc
perf(lora): populate autotune caches for common model shapes on H100
qywu May 19, 2026
a5834dd
perf(lora): kernel micro-optimisations in decode shrink/expand
qywu May 19, 2026
2cd20e4
perf(lora): grouped decode expand for tensor-core efficiency
qywu May 19, 2026
6cced6b
perf(lora): refresh autotune caches after decode kernel micro-opts
qywu May 19, 2026
f5fd737
perf(lora): expand configs BLOCK_N=128 + BLOCK_K=64 from profiling
qywu May 19, 2026
7c001ed
perf(lora): eliminate k-mask via tl.multiple_of across all decode ker…
qywu May 19, 2026
0630320
perf(lora): vLLM-style adapter-grouped expand without gather/scatter
qywu May 19, 2026
098a8cf
fix(lora): restore k-mask in grpv2 to prevent BLOCK_K > rank silent m…
qywu May 19, 2026
03f1759
Merge remote-tracking branch 'upstream/main' into pr-83-resolve
qywu May 19, 2026
dc5dfe7
perf(lora): adapter-grouped expand + correctness fix + benchmarks
qywu May 19, 2026
c008154
perf(lora): unify seg+grpv2 via max_group_size grid — no dispatch thr…
qywu May 19, 2026
7d8650a
refactor(lora): remove dead lora_expand_decode kernel
qywu May 20, 2026
bc60b53
Add MoE LoRA buffer tests and docs
qywu May 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
817 changes: 817 additions & 0 deletions bench_chunked_sgmv.py

Large diffs are not rendered by default.

141 changes: 141 additions & 0 deletions bench_kernel_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Before/after benchmark for kernel micro-optimisations + sort-by-adapter.

Tests decode shrink and expand with mixed adapters — the scenario where
sort-by-adapter actually helps (adjacent CTAs share the same weight tile).

Usage:
python bench_kernel_opt.py
"""

from __future__ import annotations

import sys
from dataclasses import dataclass
from pathlib import Path

import torch
import triton

sys.path.insert(0, str(Path(__file__).parent / "tokenspeed-kernel" / "python"))

from tokenspeed_kernel.ops.lora.triton.lora_expand import lora_expand_fwd
from tokenspeed_kernel.ops.lora.triton.lora_expand_decode import lora_expand_decode_fwd
from tokenspeed_kernel.ops.lora.triton.lora_shrink import lora_shrink_fwd


@dataclass
class BatchInfo:
bs: int
max_len: int
num_segments: int
seg_lens: torch.Tensor
seg_indptr: torch.Tensor
weight_indices: torch.Tensor
lora_ranks: torch.Tensor
scalings: torch.Tensor
permutation: torch.Tensor | None = None
sort_order: torch.Tensor | None = None
group_slots: torch.Tensor | None = None
group_starts: torch.Tensor | None = None
group_sizes: torch.Tensor | None = None
num_groups: int = 0


def make_mixed_batch(
n_segs: int,
n_unique_adapters: int,
rank: int,
device: str = "cuda",
) -> BatchInfo:
"""n_segs decode segments, round-robin across n_unique_adapters adapters."""
slots_list = [(i % n_unique_adapters) + 1 for i in range(n_segs)]
slots = torch.tensor(slots_list, dtype=torch.int32, device=device)

seg_lens = torch.ones(n_segs, dtype=torch.int32, device=device)
seg_indptr = torch.arange(n_segs + 1, dtype=torch.int32, device=device)
n_slots = n_unique_adapters + 1
lora_ranks = torch.zeros(n_slots, dtype=torch.int32, device=device)
lora_ranks[1:] = rank
scalings = torch.ones(n_slots, dtype=torch.float32, device=device)
scalings[0] = 0.0

# Build group metadata (same logic as prepare_loras)
sort_order_cpu = sorted(range(n_segs), key=lambda i: slots_list[i])
groups: list[list[int]] = []
for pos, orig in enumerate(sort_order_cpu):
slot = slots_list[orig]
if not groups or groups[-1][0] != slot:
groups.append([slot, pos, 1])
else:
groups[-1][2] += 1
ng = len(groups)
sort_order_gpu = torch.tensor(sort_order_cpu, dtype=torch.int64, device=device)
group_slots_gpu = torch.tensor(
[g[0] for g in groups], dtype=torch.int32, device=device
)
group_starts_gpu = torch.tensor(
[g[1] for g in groups], dtype=torch.int32, device=device
)
group_sizes_gpu = torch.tensor(
[g[2] for g in groups], dtype=torch.int32, device=device
)

return BatchInfo(
bs=n_segs,
max_len=1,
num_segments=n_segs,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
weight_indices=slots,
lora_ranks=lora_ranks,
scalings=scalings,
sort_order=sort_order_gpu,
group_slots=group_slots_gpu,
group_starts=group_starts_gpu,
group_sizes=group_sizes_gpu,
num_groups=ng,
)


def bench(fn, warmup=25, rep=200):
return triton.testing.do_bench(fn, warmup=warmup, rep=rep) * 1000


def run(n_segs: int, n_unique: int, rank: int, hidden: int) -> None:
dev, dt = "cuda", torch.bfloat16
n_slots = n_unique + 1
s = n_segs

bi = make_mixed_batch(n_segs, n_unique, rank, device=dev)

x_ex = torch.randn((s, rank), device=dev, dtype=dt)
w_ex = torch.randn((n_slots, hidden, rank), device=dev, dtype=dt)
o_ex = torch.zeros((s, hidden), device=dev, dtype=dt)

t_base = bench(lambda: lora_expand_fwd(x_ex, w_ex, bi, base_output=o_ex.clone()))
t_grouped = bench(
lambda: lora_expand_decode_fwd(x_ex, w_ex, bi, base_output=o_ex.clone())
)

print(
f"n_segs={n_segs:>3} n_unique={n_unique:>2} rank={rank:>3} hidden={hidden:>5} |"
f" base={t_base:>6.1f}µ grouped={t_grouped:>6.1f}µ {t_base/t_grouped:>5.2f}x"
)


if __name__ == "__main__":
# Qwen3-8B TP=2
HIDDEN, RANK = 4096, 64

print(
f"\n{'n_segs':>7} {'n_unique':>9} {'rank':>5} {'hidden':>7} | {'base':>8} {'grouped':>9} speedup"
)
print("-" * 75)
for n_unique in (1, 2, 4, 8, 16, 32):
run(n_segs=32, n_unique=n_unique, rank=RANK, hidden=HIDDEN)
print()
for n_segs in (8, 16, 32, 64, 128):
run(n_segs=n_segs, n_unique=4, rank=RANK, hidden=HIDDEN)
print()
for rank in (16, 32, 64, 128):
run(n_segs=32, n_unique=4, rank=rank, hidden=HIDDEN)
Loading
Loading