From 48e63ba6b0edc8690e50cbb126489f92f47b112e Mon Sep 17 00:00:00 2001 From: yihonglie Date: Sat, 30 May 2026 12:03:02 -0500 Subject: [PATCH 01/27] [kv_transfer] ATOM standalone LMCache CPU/NVMe KV-offload connector (WIP) New atom/kv_transfer/offload/ package: a standalone (non-vLLM) KV offload connector that reuses LMCache as a storage tier (CPU LRU + NVMe L3) via StorageManager + ChunkedTokenDatabase, with an opaque per-block byte codec (ATOMKVByteCodec) that bypasses engine.store/retrieve to preserve AITER's swizzled KV layout. Load/save run daemon-after-forward off the RPC thread; cross-process hit lookup via LMCache's ZMQ LookupClient/LookupServer. Engine hooks: - disaggregation/factory.py: register the "lmcache_offload" connector. - model_engine/scheduler.py: offload-wake branch (a parked WAITING_FOR_REMOTE_KVS seq resumes as a suffix prefill, not the P/D decode-jump) + offload_resume guard against re-allocate-on-populated-blocks. TP=2 fixes (offload was non-functional at TP>1): - config.py: lookup_server_worker_ids=[0]. The cross-rank ZMQ lookup took min() across ranks and rank!=0 returned 0 despite having stored the chunk (contains()=True) -> min(0,hit)=0 -> the load never fired. Only rank 0 answers lookup now (both ranks save in lockstep, so rank 0 is authoritative). - connector.py: split load/save into separate executors so a latency-critical reload never queues behind the fire-and-forget save backlog. Verified end-to-end at TP=2 (MiniMax-M2.5 FP8, 2x MI325X): an evicted 32K prompt reloads 32000 tokens from CPU and recomputes only the 5-token suffix (Scheduled prefill cached:[32000], new:[5]); was a full recompute before. Known issues (WIP): - Reload latency high (~131s in the micro-bench): the per-block Python copy path is slow (~88ms/chunk) and the load waits on the storage_manager lock held by a save burst. Needs a bulk/batched copy rewrite of ATOMKVByteCodec. See ../OFFLOAD_TP2_FIXES.md and ../PHASE4_RESULTS.md. - Verbose [OFFLOAD-*] diagnostic logging + an engine.lookup monkeypatch are still present and must be removed/demoted before benchmark/production use. Co-Authored-By: Claude Opus 4.8 (1M context) --- atom/kv_transfer/disaggregation/factory.py | 9 + atom/kv_transfer/offload/__init__.py | 20 + atom/kv_transfer/offload/config.py | 95 +++++ atom/kv_transfer/offload/connector.py | 466 +++++++++++++++++++++ atom/kv_transfer/offload/gpu_connector.py | 119 ++++++ atom/kv_transfer/offload/metadata.py | 78 ++++ atom/model_engine/scheduler.py | 114 +++-- 7 files changed, 870 insertions(+), 31 deletions(-) create mode 100644 atom/kv_transfer/offload/__init__.py create mode 100644 atom/kv_transfer/offload/config.py create mode 100644 atom/kv_transfer/offload/connector.py create mode 100644 atom/kv_transfer/offload/gpu_connector.py create mode 100644 atom/kv_transfer/offload/metadata.py diff --git a/atom/kv_transfer/disaggregation/factory.py b/atom/kv_transfer/disaggregation/factory.py index d18d621c40..11b795636b 100644 --- a/atom/kv_transfer/disaggregation/factory.py +++ b/atom/kv_transfer/disaggregation/factory.py @@ -134,3 +134,12 @@ def create_connector( scheduler_module="atom.kv_transfer.disaggregation.mooncake.mooncake_connector", scheduler_class="MooncakeConnectorScheduler", ) + + +# ATOM standalone CPU/NVMe KV offload backend (registers "lmcache_offload"). +# Import is lightweight (offload/__init__ only records module paths as strings; +# the connector module is imported lazily by create_connector when selected). +try: + import atom.kv_transfer.offload # noqa: F401,E402 +except Exception as _e: # pragma: no cover - offload optional (needs lmcache) + logger.debug("lmcache_offload backend not registered: %s", _e) diff --git a/atom/kv_transfer/offload/__init__.py b/atom/kv_transfer/offload/__init__.py new file mode 100644 index 0000000000..b9a9edf45f --- /dev/null +++ b/atom/kv_transfer/offload/__init__.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""ATOM standalone LMCache CPU/NVMe KV-offload connector. + +Registers the ``lmcache_offload`` backend with the shared KV connector factory. +Enable via ``--kv-transfer-config '{"kv_connector":"lmcache_offload","kv_role":"offload"}'`` +plus LMCache env (``LMCACHE_LOCAL_CPU=True``, ``LMCACHE_MAX_LOCAL_CPU_SIZE``, +``LMCACHE_CHUNK_SIZE=256``, optional ``LMCACHE_LOCAL_DISK`` for the NVMe L3 tier). +""" + +from atom.kv_transfer.disaggregation.factory import KVConnectorFactory + +KVConnectorFactory.register( + "lmcache_offload", + worker_module="atom.kv_transfer.offload.connector", + worker_class="LMCacheOffloadConnector", + scheduler_module="atom.kv_transfer.offload.connector", + scheduler_class="LMCacheOffloadConnectorScheduler", +) diff --git a/atom/kv_transfer/offload/config.py b/atom/kv_transfer/offload/config.py new file mode 100644 index 0000000000..83fe7c2cf1 --- /dev/null +++ b/atom/kv_transfer/offload/config.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Build the per-rank ``LMCacheEngineConfig`` + ``LMCacheMetadata`` for the +ATOM standalone offload connector. + +LMCache is driven by ``LMCACHE_*`` env vars (``LMCACHE_LOCAL_CPU``, +``LMCACHE_MAX_LOCAL_CPU_SIZE``, ``LMCACHE_CHUNK_SIZE``, ``LMCACHE_LOCAL_DISK``, +``LMCACHE_MAX_LOCAL_DISK_SIZE`` …) exactly like the vLLM recipe. We additionally +allow overrides via ``kv_transfer_config`` extras keyed ``lmcache.`` and +force ``use_gds=False`` (cufile GDS init hangs without NVMe-GDS hardware). +""" + +from __future__ import annotations + +import os +from typing import Any + +import torch + + +def build_lmcache_config(): + """Return an ``LMCacheEngineConfig`` from ``LMCACHE_*`` env + extras.""" + from lmcache.v1.config import LMCacheEngineConfig + + cfg = LMCacheEngineConfig.from_env() + # cufile GDS has no NVMe-GDS hardware here and hangs on init; force off. + if getattr(cfg, "use_gds", False): + try: + cfg.use_gds = False + except Exception: + pass + # TP>1 fix: only rank 0 serves/answers the ZMQ lookup. Without this the + # client queries all ranks and takes min() over results; we observed rank!=0 + # engine.lookup returning 0 even though that rank stored the chunk + # (contains()=True) -> min(0, hit)=0 -> the scheduler never sees the hit and + # always recomputes. Our connector saves on ALL ranks in lockstep, so rank 0 + # is authoritative for "is it offloaded?"; each rank still loads its own KV + # shard, and _do_load is all-or-nothing (re-prefills if a shard is missing). + try: + cfg.lookup_server_worker_ids = [0] + except Exception: + pass + return cfg + + +def apply_extra_overrides(cfg, kv_transfer_config: dict[str, Any] | None) -> None: + """Apply ``{"lmcache.": value}`` extras from kv_transfer_config.""" + if not kv_transfer_config: + return + extra = kv_transfer_config.get("kv_connector_extra_config", kv_transfer_config) + for key, value in (extra or {}).items(): + if isinstance(key, str) and key.startswith("lmcache."): + field = key[len("lmcache.") :] + if hasattr(cfg, field): + try: + setattr(cfg, field, value) + except Exception: + pass + + +def build_lmcache_metadata(config, cfg, world_size: int, worker_id: int): + """Build ``LMCacheMetadata`` for this rank from ATOM ``config`` + LMCache cfg. + + ``kv_shape`` follows LMCache's ``(num_layers, 2, chunk_size, num_kv_heads, + head_dim)`` convention. For our opaque BINARY-style storage the exact dims + are only used for key/shape bookkeeping (we override the byte layout in the + codec), but we fill them faithfully from hf_config so logging/keys are sane. + """ + from aiter import dtypes + from lmcache.v1.metadata import LMCacheMetadata + + hf = config.hf_config + num_layers = int(getattr(hf, "num_hidden_layers")) + num_kv_heads = int(getattr(hf, "num_key_value_heads", getattr(hf, "num_attention_heads"))) + tp = int(getattr(config, "tensor_parallel_size", world_size) or 1) + num_kv_heads_local = max(1, num_kv_heads // tp) + head_dim = int(getattr(hf, "head_dim", 0) or (hf.hidden_size // hf.num_attention_heads)) + kv_dtype = dtypes.d_dtypes[config.kv_cache_dtype] + model_name = str(getattr(config, "model", "atom-model")) + + return LMCacheMetadata( + model_name=model_name, + world_size=world_size, + local_world_size=world_size, + worker_id=worker_id, + local_worker_id=worker_id, + kv_dtype=kv_dtype, + kv_shape=(num_layers, 2, int(cfg.chunk_size), num_kv_heads_local, head_dim), + use_mla=False, + chunk_size=int(cfg.chunk_size), + # Shared id so the scheduler's ZMQ LookupClient and each worker's + # LookupServer derive the SAME ipc socket path (get_zmq_rpc_path_lmcache). + engine_id="atom-offload", + ) diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py new file mode 100644 index 0000000000..2cc34ca7cf --- /dev/null +++ b/atom/kv_transfer/offload/connector.py @@ -0,0 +1,466 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""ATOM standalone LMCache CPU/NVMe KV-offload connector. + +Design (see ../../../../PLAN_impl_lmcache_offload_v5.md + 005 LEARN notes): + +* **Reuse real LMCache as a storage tier only** — per-rank ``LMCacheEngine`` for its + ``StorageManager`` (CPU LRU + NVMe L3) + ``ChunkedTokenDatabase`` (chunk-256 keys). + We bypass ``engine.store/retrieve`` (token-major GPU path can't represent AITER's + swizzle) and instead move **opaque per-block bytes** via :class:`ATOMKVByteCodec` + into pinned ``KV_2LTD``-as-uint8 ``MemoryObj``s. +* **Daemon-after-forward copies** — ``start_load_kv`` only ``submit``s to a single + serial copy daemon (ThreadPoolExecutor max_workers=1) and returns immediately, so + the worker RPC thread is free for ``forward``; completions are polled in + ``get_finished`` (called post-forward by ``async_proc_aggregation``). This is the + fix for 005's "load blocks/starves prefill" (corr(TTFT, prefill-conc)=0.773). +* **Cross-process hit lookup** — scheduler (EngineCore process) queries worker hits + via LMCache's ZMQ ``LookupClient``/``LookupServer`` (no homegrown mirror). +""" + +from __future__ import annotations + +import logging +import os +import threading +from concurrent.futures import ThreadPoolExecutor + +import torch + +from atom.kv_transfer.disaggregation.base import ( + KVConnectorBase, + KVConnectorSchedulerBase, +) +from atom.kv_transfer.offload import config as offcfg +from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec +from atom.kv_transfer.offload.metadata import ( + LMCacheOffloadMetadata, + LMCacheReqMeta, + LoadSpec, + SaveSpec, +) + +logger = logging.getLogger("atom") + + +def _cdiv(a: int, b: int) -> int: + return -(-a // b) + + +class _UnusedGPUConnector: + """Satisfies LMCacheEngineBuilder.get_or_create; never invoked (we do our own + byte-copy and never call engine.store/retrieve).""" + + def to_gpu(self, *a, **k): + raise NotImplementedError + + def from_gpu(self, *a, **k): + raise NotImplementedError + + def batched_from_gpu(self, *a, **k): + raise NotImplementedError + + def batched_to_gpu(self, *a, **k): + raise NotImplementedError + + def get_shape(self, num_tokens): + return torch.Size((num_tokens,)) + + +# ===================================================================== +# Worker side +# ===================================================================== +class LMCacheOffloadConnector(KVConnectorBase): + # Offload is a *consumer* from the scheduler's POV (it loads KV back). Saves + # are fire-and-forget on the worker and must NOT be reported as + # finished_sending (the scheduler frees blocks on finished_sending — a P/D + # producer semantic that would wrongly deallocate live offload blocks). + is_producer = False + + def __init__(self, config) -> None: + self._config = config + kvc = getattr(config, "kv_transfer_config", {}) or {} + self.kv_role = kvc.get("kv_role", "offload") + self._do_save = self.kv_role in ("offload", "kv_both", "kv_producer") + self._do_load = self.kv_role in ("offload", "kv_both", "kv_consumer") + self.block_size = int(config.kv_cache_block_size) + self.chunk_size: int | None = None + + # Copy daemons: keep GPU<->host copies off the RPC thread. SEPARATE + # executors for LOAD vs SAVE so a load (on the TTFT critical path — a + # parked seq is waiting for it) never queues behind a backlog of fire- + # and-forget saves (Phase 4 root cause: with one shared serial daemon, a + # reload sat behind ~N filler saves -> request hung well past timeout). + # Each worker thread gets its OWN CUDA stream (disjoint block_ids -> no + # write conflict). OFFLOAD_COPY_WORKERS tunes the SAVE pool only. + n_save_workers = int(os.environ.get("OFFLOAD_COPY_WORKERS", "1")) + self._load_executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="lmc-offload-load" + ) + self._save_executor = ThreadPoolExecutor( + max_workers=n_save_workers, thread_name_prefix="lmc-offload-save" + ) + self._tls = threading.local() # per-thread copy stream + self._lock = threading.Lock() + self._done_load: set[str] = set() + self._done_save: set[str] = set() + + self._engine = None + self._sm = None + self._tdb = None + self._codec: ATOMKVByteCodec | None = None + self._lookup_server = None + + # -- lifecycle -------------------------------------------------------- + def register_kv_caches(self, kv_caches: dict, transfer_tensors=None) -> None: + from aiter.dist.parallel_state import get_tp_group + from lmcache.v1.cache_engine import LMCacheEngineBuilder + + tp = get_tp_group() + rank, world = tp.rank_in_group, tp.world_size + self._rank = rank + + cfg = offcfg.build_lmcache_config() + offcfg.apply_extra_overrides(cfg, getattr(self._config, "kv_transfer_config", None)) + meta = offcfg.build_lmcache_metadata(self._config, cfg, world, rank) + self.chunk_size = int(cfg.chunk_size) + + self._engine = LMCacheEngineBuilder.get_or_create( + f"atom-offload-{rank}", cfg, meta, _UnusedGPUConnector(), + lambda t, s: None, lambda o, s: o, + ) + self._engine.post_init() + self._sm = self._engine.storage_manager + self._tdb = self._engine.token_database + self._codec = ATOMKVByteCodec(kv_caches) + + # DEBUG: wrap engine.lookup to capture EVERY call (incl. the ones the ZMQ + # lookup_server makes on behalf of the scheduler) — args + result. + _orig_lookup = self._engine.lookup + _rk = rank + def _logged_lookup(*a, **k): + r = _orig_lookup(*a, **k) + h = k.get("hashes") + logger.info("[ENGINE.LOOKUP] rank=%s lookup_id=%s nhashes=%s first3=%s -> %s", + _rk, k.get("lookup_id"), (len(h) if h is not None else None), + (list(h[:3]) if h else None), r) + return r + self._engine.lookup = _logged_lookup + + # ZMQ lookup server so the scheduler process can query our hit counts. + try: + from lmcache.v1.lookup_client.factory import LookupClientFactory + self._lookup_server = LookupClientFactory.create_lookup_server( + self._engine, meta + ) + except Exception as e: # lookup server optional for save-only smoke + logger.warning("LMCache offload: lookup server not started: %s", e) + + logger.info( + "LMCache offload worker rank=%d: bytes_per_block=%d chunk=%d save=%s load=%s", + rank, self._codec.bytes_per_block, self.chunk_size, + self._do_save, self._do_load, + ) + + # -- per-step (RPC thread): only enqueue, never copy ------------------ + def start_load_kv(self, metadata) -> None: + if not isinstance(metadata, LMCacheOffloadMetadata): + return + for req in metadata.requests: + if req.load_spec is not None and self._do_load: + self._load_executor.submit(self._guard, self._do_load_req, req) + if req.save_spec is not None and self._do_save: + self._save_executor.submit(self._guard, self._do_save_req, req) + + def _guard(self, fn, req) -> None: + try: + fn(req) + except Exception: + logger.exception("LMCache offload: %s failed for %s", fn.__name__, req.req_id) + # Wake the seq anyway so it is not stuck parked; scheduler re-derives + # how much is actually cached (load) / proceeds (save). + with self._lock: + (self._done_load if fn is self._do_load_req else self._done_save).add( + req.req_id + ) + + def _stream(self) -> torch.cuda.Stream: + """A CUDA stream owned by the calling copy-daemon thread (lazily made).""" + s = getattr(self._tls, "stream", None) + if s is None: + s = torch.cuda.Stream() + self._tls.stream = s + return s + + def _block_ids(self, req: LMCacheReqMeta, start: int, end: int) -> list[int]: + return req.block_ids[start // self.block_size : _cdiv(end, self.block_size)] + + # -- copy daemon thread ---------------------------------------------- + def _do_load_req(self, req: LMCacheReqMeta) -> None: + ls = req.load_spec + assert ls is not None + stream = self._stream() + hbm = (ls.hbm_cached_tokens // self.chunk_size) * self.chunk_size + toks = req.token_ids[: ls.lmcache_cached_tokens] + mask = torch.ones(len(toks), dtype=torch.bool) + mask[:hbm] = False + chunks = list(self._tdb.process_tokens(torch.tensor(toks), mask=mask)) + logger.debug("offload _do_load req=%s hbm=%d lmc=%d chunks=%d", + req.req_id, hbm, ls.lmcache_cached_tokens, len(chunks)) + + # All-or-nothing: a partial load would let attention read uninitialized + # blocks. If any chunk is gone (evicted between lookup and load), skip the + # whole load — the seq wakes and re-prefills the suffix (loaded 0). + for (_s, _e, key) in chunks: + if not self._sm.contains(key): + logger.warning("LMCache offload: load miss req=%s; re-prefill", req.req_id) + with self._lock: + self._done_load.add(req.req_id) + return + + for (s, e, key) in chunks: + mo = self._sm.get(key) + if mo is None: + with self._lock: + self._done_load.add(req.req_id) + return + self._codec.host_to_gpu(mo.tensor, self._block_ids(req, s, e), stream) + mo.ref_count_down() + stream.synchronize() + # Release the lookup pin (taken by the scheduler's LookupClient.lookup) + # now that the chunks are safely in GPU; lets the pool evict them later. + try: + self._engine.lookup_unpin([str(req.req_id)]) # LMCache pin keyed by str id + except Exception: + pass + with self._lock: + self._done_load.add(req.req_id) + logger.info("offload _do_load DONE req=%s", req.req_id) + + def _do_save_req(self, req: LMCacheReqMeta) -> None: + from lmcache.v1.memory_management import MemoryFormat + + ss = req.save_spec + assert ss is not None + stream = self._stream() + toks = req.token_ids + if not req.is_last_prefill: + toks = toks[: (len(toks) // self.chunk_size) * self.chunk_size] + skip = (ss.skip_leading_tokens // self.chunk_size) * self.chunk_size + if skip >= len(toks): + with self._lock: + self._done_save.add(req.req_id) + return + + mask = torch.ones(len(toks), dtype=torch.bool) + mask[:skip] = False + chunks = list(self._tdb.process_tokens(torch.tensor(toks), mask=mask)) + + keys, objs, already = [], [], 0 + for (s, e, key) in chunks: + if self._sm.contains(key): # already offloaded → skip wasted D2H + already += 1 + continue + bids = self._block_ids(req, s, e) + nbytes = len(bids) * self._codec.bytes_per_block + mo = self._sm.allocate(torch.Size((nbytes,)), torch.uint8, + fmt=MemoryFormat.KV_2LTD) + if mo is None: # pool under pressure; stop here + break + # D2H on this thread's dedicated copy stream (off the compute stream). + self._codec.gpu_to_host(mo.tensor, bids, stream) + keys.append(key) + objs.append(mo) + + if keys: + stream.synchronize() # stream-specific + self._sm.batched_put(keys, objs) + with self._lock: + self._done_save.add(req.req_id) + _kh = [getattr(k, "chunk_hash", None) for k in keys[:2]] + _contains = [bool(self._sm.contains(k)) for k in keys[:2]] + logger.info("[OFFLOAD-SAVE] rank=%s req=%s toks=%d chunks=%d stored=%d already=%d " + "chunkhash2=%s contains=%s", + self._rank, req.req_id, len(toks), len(chunks), len(keys), + already, _kh, _contains) + + # -- per-step (RPC thread, post-forward): poll completions ------------ + def get_finished(self) -> tuple[set, set]: + # (finished_sending, finished_recving). Offload SAVES are fire-and-forget + # (they don't free blocks), so finished_sending is ALWAYS empty; only + # completed LOADS are reported, to wake parked seqs for suffix prefill. + with self._lock: + dl = set(self._done_load) + self._done_save.clear() + self._done_load.clear() + return set(), dl + + def get_finished_recv_blocks(self) -> list[int]: + # Local CUDA copies are ordered by the copy stream + synchronize() before + # we mark done; no RDMA-style GPU fence needed. + return [] + + +# ===================================================================== +# Scheduler side +# ===================================================================== +class LMCacheOffloadConnectorScheduler(KVConnectorSchedulerBase): + # Consumer semantics: finished_recving wakes parked seqs (the engine asserts + # `not is_producer` on that path). Offload never uses finished_sending. + is_producer = False + # Opt the scheduler into offload-wake (suffix prefill) instead of the P/D + # decode-jump in Scheduler.schedule(); see Scheduler._is_offload_connector. + is_offload = True + + def __init__(self, config) -> None: + self._config = config + kvc = getattr(config, "kv_transfer_config", {}) or {} + self.kv_role = kvc.get("kv_role", "offload") + self.block_size = int(config.kv_cache_block_size) + self.chunk_size: int | None = None + self._lookup_client = None + + # req_id -> LoadSpec (pending load decided at match time) + self._load_specs: dict[str, LoadSpec] = {} + # req_id -> Sequence (queued to recv this step) + self._reqs_need_recv: dict[str, object] = {} + # Persistent save tracker: sid -> [seq, saved_offset]. A seq's prompt + # prefix is stored to LMCache once prefill computes it + # (seq.prefix_hashes_published flips True), chunk by chunk. + self._save_tracker: dict[str, list] = {} + self._lookup_in_step: list[str] = [] + + try: + cfg = offcfg.build_lmcache_config() + offcfg.apply_extra_overrides(cfg, kvc) + self.chunk_size = int(cfg.chunk_size) + from lmcache.v1.lookup_client.factory import LookupClientFactory + world = int(getattr(config, "tensor_parallel_size", 1) or 1) + meta = offcfg.build_lmcache_metadata(config, cfg, world, 0) + self._lookup_client = LookupClientFactory.create_lookup_client(cfg, meta) + except Exception as e: + logger.warning("LMCache offload scheduler: lookup client unavailable: %s", e) + + # -- match: how many extra tokens can come from CPU/NVMe ------------- + def get_num_new_matched_tokens(self, seq) -> tuple[int, bool]: + if self._lookup_client is None: + return 0, False + num_prompt = seq.num_prompt_tokens + token_ids = list(seq.token_ids[:num_prompt]) + try: + hit = self._lookup_client.lookup(token_ids, lookup_id=str(seq.id)) + except Exception: + logger.exception("LMCache offload lookup failed for seq %s", seq.id) + return 0, False + _lh = None + try: + tdb = getattr(self._lookup_client, "token_database", None) + if tdb is not None: + _lh = [k for (_s, _e, k) in list( + tdb.process_tokens(token_ids, make_key=False))[:3]] + except Exception as e: + _lh = f"err:{e}" + logger.info("[OFFLOAD-LOOKUP] seq=%s num_prompt=%d hbm_cached=%d hit=%s lookuphash3=%s", + seq.id, num_prompt, int(seq.num_cached_tokens), hit, _lh) + if not hit: + return 0, False + self._lookup_in_step.append(str(seq.id)) + need = int(hit) - int(seq.num_cached_tokens) + if int(hit) == num_prompt: # full-prompt hit → recompute last token + need -= 1 + if need <= 0: + return 0, False + self._load_specs[str(seq.id)] = LoadSpec( + hbm_cached_tokens=int(seq.num_cached_tokens), + lmcache_cached_tokens=int(hit), + can_load=False, + ) + return need, True # True => park in WAITING_FOR_REMOTE_KVS + + def update_state_after_alloc(self, seq) -> None: + sid = str(seq.id) + ls = self._load_specs.get(sid) + logger.info("[OFFLOAD-ALLOC] seq=%s ls_found=%s num_cached_now=%s", + seq.id, ls is not None, int(getattr(seq, "num_cached_tokens", -1))) + if ls is not None: + ls.can_load = True + self._reqs_need_recv[sid] = seq + # Track for save; the prompt prefix is offloaded later, once prefill has + # actually computed it (checked via prefix_hashes_published in build). + if sid not in self._save_tracker: + self._save_tracker[sid] = [seq, 0] + + def build_connector_meta(self) -> LMCacheOffloadMetadata: + meta = LMCacheOffloadMetadata() + meta.lookup_requests_in_step = self._lookup_in_step + self._lookup_in_step = [] + + # Loads + logger.info("[OFFLOAD-BUILD] reqs_need_recv=%d", len(self._reqs_need_recv)) + for sid, seq in self._reqs_need_recv.items(): + ls = self._load_specs.pop(sid, None) + if ls is None or not ls.can_load: + logger.info("[OFFLOAD-LOAD-SKIP] seq=%s ls=%s can_load=%s", + sid, ls is not None, getattr(ls, "can_load", None)) + continue + # ★ Use the REAL HBM-cached count as the load floor. + # get_num_new_matched_tokens runs BEFORE the prefix-cache match in + # block_manager.allocate, so seq.num_cached_tokens was stale (often + # 0) when the LoadSpec was recorded. By now (post-allocate) it is the + # true HBM hit. Loading below this floor would overwrite HBM + # prefix-cache blocks (possibly shared with other seqs) -> output + # corruption. So load only [hbm_cached, offload_hit). + ls.hbm_cached_tokens = int(seq.num_cached_tokens) + # num_cached after load = max(HBM, offload); never drop below HBM. + seq.offload_loaded_tokens = max( + int(seq.num_cached_tokens), int(ls.lmcache_cached_tokens) + ) + # req_id MUST be the raw seq.id (the type the scheduler compares + # against in _update_waiting_for_remote_kv); str(seq.id) is only for + # LMCache's lookup/pin API. A str here silently never wakes the seq. + logger.info("[OFFLOAD-LOAD-EMIT] seq=%s hbm_cached=%d lmc_cached=%d offload_loaded=%d nblocks=%d", + seq.id, ls.hbm_cached_tokens, ls.lmcache_cached_tokens, + seq.offload_loaded_tokens, len(list(seq.block_table))) + meta.add_request(LMCacheReqMeta( + req_id=seq.id, + token_ids=list(seq.token_ids[: ls.lmcache_cached_tokens]), + block_ids=list(seq.block_table), + load_spec=ls, + )) + # Saves: store the prompt prefix once prefill has computed it. We detect + # "computed" via seq.prefix_hashes_published (set in postprocess after the + # prefill step), so the blocks we D2H are already written -- no race with + # forward. Persistent tracker: each chunk is stored once. + chunk = self.chunk_size or 256 + for sid, entry in self._save_tracker.items(): + seq, saved = entry + if sid in self._reqs_need_recv: + continue # loading this step; defer its save + if not getattr(seq, "prefix_hashes_published", False): + continue # prefill not finished computing the prompt yet + aligned = (int(seq.num_prompt_tokens) // chunk) * chunk + if aligned <= saved: + continue + logger.info("[OFFLOAD-SAVE-EMIT] seq=%s num_prompt=%d aligned=%d saved=%d", + seq.id, int(seq.num_prompt_tokens), aligned, saved) + meta.add_request(LMCacheReqMeta( + req_id=seq.id, + token_ids=list(seq.token_ids[:aligned]), + block_ids=list(seq.block_table), + save_spec=SaveSpec(skip_leading_tokens=saved, can_save=True), + )) + entry[1] = aligned + self._reqs_need_recv.clear() + return meta + + def request_finished(self, seq) -> None: + sid = str(seq.id) + self._load_specs.pop(sid, None) + self._reqs_need_recv.pop(sid, None) + self._save_tracker.pop(sid, None) + if self._lookup_client is not None: + try: + self._lookup_client.clear_lookup_status(sid) + except Exception: + pass diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py new file mode 100644 index 0000000000..01e6c68af2 --- /dev/null +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""AITER-layout-aware byte codec between ATOM's paged GPU KV cache and a flat +pinned host buffer (an LMCache ``MemoryObj``'s ``uint8`` tensor). + +Why a byte codec instead of an LMCache ``GPUConnectorInterface`` subclass: +LMCache's ``engine.store/retrieve`` GPU path only emits token-major formats +(``KV_2LTD`` etc.) via ``normalize_kv_and_discover_format``, which rejects +AITER's swizzled K layout ``(nb, H, D//x, bs, x)`` and strided V ``(nb, H, D, bs)``. +We therefore bypass that path: we store **opaque per-block bytes** (byte-identical +round-trip — the attention kernel reads back its own layout) and drive LMCache only +as a storage tier (``StorageManager`` + ``ChunkedTokenDatabase``). + +A whole *block* of any per-layer cache tensor (``t[block_id]``) is contiguous, so a +block's KV is a set of contiguous byte slices: per layer K, V, and (fp8) k_scale, +v_scale. The flat per-block layout in the host buffer is:: + + [ L0.K | L0.V | L0.kS | L0.vS | L1.K | L1.V | ... ] (only present segments) + +which is self-consistent for store and load (we never reinterpret it). +""" + +from __future__ import annotations + +import torch + + +class ATOMKVByteCodec: + """Per-block byte mover between paged GPU KV tensors and a flat host buffer.""" + + def __init__(self, kv_caches: dict) -> None: + """``kv_caches``: ordered ``{layer_name: KVCacheTensor}`` from + ``register_kv_caches``. We flatten every movable per-layer tensor (K, V, + and fp8 scales when present) into one ordered segment list. Each segment + is a GPU tensor shaped ``[num_blocks, ...]``; segment[block_id] is a + contiguous block slice we copy as raw bytes.""" + self._segments: list[torch.Tensor] = [] + for _name, kvt in kv_caches.items(): + for t in ( + getattr(kvt, "k_cache", None), + getattr(kvt, "v_cache", None), + getattr(kvt, "k_scale", None), + getattr(kvt, "v_scale", None), + ): + if t is not None and isinstance(t, torch.Tensor) and t.numel() > 0: + self._segments.append(t) + + if not self._segments: + raise ValueError("ATOMKVByteCodec: no movable KV tensors registered") + + # Bytes for one block of each segment (block is dim 0). + self._seg_block_bytes: list[int] = [ + int(t[0].numel()) * t.element_size() for t in self._segments + ] + # Byte offset of each segment within one block's flat record. + self._seg_off: list[int] = [] + acc = 0 + for nb in self._seg_block_bytes: + self._seg_off.append(acc) + acc += nb + self.bytes_per_block: int = acc + self.num_blocks: int = int(self._segments[0].shape[0]) + + # -- helpers ---------------------------------------------------------- + @staticmethod + def _block_bytes_view(seg: torch.Tensor, block_id: int) -> torch.Tensor: + """Flat ``uint8`` view of one contiguous block slice (no copy).""" + blk = seg[block_id] + if not blk.is_contiguous(): + # Block slices of the paged cache are contiguous in practice; guard + # anyway. A non-contiguous block would break in-place H2D, so fail loud. + raise RuntimeError("ATOMKVByteCodec: block slice not contiguous") + return blk.reshape(-1).view(torch.uint8) + + # -- public API ------------------------------------------------------- + def gpu_to_host( + self, + host_buf: torch.Tensor, + block_ids: list[int], + stream: torch.cuda.Stream | None = None, + ) -> None: + """D2H: gather ``block_ids`` from the paged GPU cache into the flat + pinned ``host_buf`` (uint8, length == len(block_ids) * bytes_per_block).""" + ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() + with ctx: + for i, bid in enumerate(block_ids): + base = i * self.bytes_per_block + for seg, off, nb in zip( + self._segments, self._seg_off, self._seg_block_bytes + ): + src = self._block_bytes_view(seg, bid) + host_buf[base + off : base + off + nb].copy_(src, non_blocking=True) + + def host_to_gpu( + self, + host_buf: torch.Tensor, + block_ids: list[int], + stream: torch.cuda.Stream | None = None, + ) -> None: + """H2D: scatter the flat pinned ``host_buf`` back into the paged GPU + cache at ``block_ids`` (in-place into the real KV tensors).""" + ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() + with ctx: + for i, bid in enumerate(block_ids): + base = i * self.bytes_per_block + for seg, off, nb in zip( + self._segments, self._seg_off, self._seg_block_bytes + ): + dst = self._block_bytes_view(seg, bid) + dst.copy_(host_buf[base + off : base + off + nb], non_blocking=True) + + +class _nullctx: + def __enter__(self): + return None + + def __exit__(self, *a): + return False diff --git a/atom/kv_transfer/offload/metadata.py b/atom/kv_transfer/offload/metadata.py new file mode 100644 index 0000000000..f4275e571b --- /dev/null +++ b/atom/kv_transfer/offload/metadata.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Per-request transfer descriptors for the LMCache CPU/NVMe offload connector. + +Ported (type-substituted) from vLLM's ``lmcache_integration/vllm_v1_adapter.py`` +(``LoadSpec`` / ``SaveSpec`` / ``RequestTracker`` / ``ReqMeta``) onto ATOM's +``Sequence`` model. These travel from the scheduler-side connector to the +worker-side connector inside :class:`LMCacheOffloadMetadata`, which subclasses +ATOM's :class:`ConnectorMetadata` so the engine forwards it opaquely through +``process_kvconnector_output`` → ``start_load_kv``. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from atom.kv_transfer.disaggregation.types import ConnectorMetadata + + +@dataclass +class LoadSpec: + """How many tokens to load for a request, split HBM-cached vs LMCache-cached.""" + + # Tokens already resident in ATOM's HBM prefix cache (num_cached_tokens). + hbm_cached_tokens: int + # Total tokens LMCache can supply (>= hbm_cached_tokens). The load fills the + # gap [hbm_cached_tokens, lmcache_cached_tokens). + lmcache_cached_tokens: int + # Set True by update_state_after_alloc once blocks are reserved for the load. + can_load: bool = False + + +@dataclass +class SaveSpec: + """How many leading tokens of a request are already saved to LMCache.""" + + # Tokens at the prefix already persisted (skip these on the next store). + skip_leading_tokens: int + # Set False to suppress the store for this step (e.g. nothing new to save). + can_save: bool = True + + +@dataclass +class LMCacheReqMeta: + """Everything the worker needs to load/save one request's KV this step.""" + + req_id: str + # Token ids covering the prefix being moved (used to derive chunk-256 keys via + # LMCache's ChunkedTokenDatabase). For load: prompt[:lmcache_cached_tokens]; + # for save: computed token ids. + token_ids: list[int] + # The sequence's GPU block table (logical block ids). A chunk spanning token + # range [start, end) maps to blocks block_ids[start // bs : ceil(end / bs)]. + block_ids: list[int] + load_spec: LoadSpec | None = None + save_spec: SaveSpec | None = None + # True on the request's final prefill chunk (store the unaligned tail too). + is_last_prefill: bool = True + + +class LMCacheOffloadMetadata(ConnectorMetadata): + """Connector metadata snapshot for one engine step. + + Subclasses ATOM's :class:`ConnectorMetadata` (so it satisfies the + ``build_connector_meta() -> ConnectorMetadata`` contract and is forwarded + opaquely by the engine) while carrying the richer per-request offload + descriptors the worker consumes in ``start_load_kv``. + """ + + def __init__(self) -> None: + super().__init__() + self.requests: list[LMCacheReqMeta] = [] + # req_ids whose scheduler-side lookup pin should be released this step. + self.lookup_requests_in_step: list[str] = [] + + def add_request(self, meta: LMCacheReqMeta) -> None: + self.requests.append(meta) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 57d4acbe8d..3bc6c3f07f 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -716,8 +716,42 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: skipped_waiting_requests.append(seq) continue + # OFFLOAD fresh-wake: a seq whose CPU/NVMe prefix just finished + # loading into its GPU blocks. Unlike P/D (which jumps straight to + # decode with an injected first token), offload must resume as a + # PREFILL of only the un-loaded SUFFIX: bump num_cached_tokens to the + # loaded count and fall through to admission WITHOUT re-matching or + # re-allocating (blocks were allocated for the whole seq before it + # parked). `offload_resume` also covers the case where a woken + # offload seq did not fit this step and was re-parked as plain + # WAITING -- it prevents re-running get_num_new_matched_tokens and, + # critically, re-calling block_manager.allocate() onto an already + # populated block_table (005's `assert not seq.block_table` crash). + is_offload = self._is_offload_connector() + if waiting_remote_to_waiting_ready and is_offload: + loaded = getattr(seq, "offload_loaded_tokens", None) + logger.debug( + "[OFFLOAD-WAKE] seq %s: loaded=%s prev_cached=%d num_tokens=%d", + seq.id, loaded, seq.num_cached_tokens, seq.num_tokens, + ) + if loaded is not None and loaded > seq.num_cached_tokens: + seq.num_cached_tokens = loaded + seq.offload_loaded = True + waiting_remote_to_waiting_ready = False # not the P/D decode jump + + offload_resume = ( + is_offload + and getattr(seq, "offload_loaded", False) + and len(seq.block_table) > 0 + and seq.num_cached_tokens > 0 + ) + need_to_remove_to_load_kv_async_queue = False - if self.kv_connector is not None and not waiting_remote_to_waiting_ready: + if ( + self.kv_connector is not None + and not waiting_remote_to_waiting_ready + and not offload_resume + ): _ext_tokens, need_to_remove_to_load_kv_async_queue = ( self.kv_connector.get_num_new_matched_tokens(seq) ) @@ -743,43 +777,53 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: self.running.append(seq) continue - # Probe cache hits FIRST so budget check sees the real - # (post-prefix-cache) remaining token count. `can_allocate` - # excludes the last block from cache hits (prefill must forward - # at least one block to produce logits), so num_new_tokens ≥ 1 - # is guaranteed. - num_cached_blocks = self.block_manager.can_allocate(seq) - if num_cached_blocks < 0: - self.waiting.appendleft(seq) - break - - num_new_tokens = ( - seq.num_prompt_tokens - - num_cached_blocks * self.block_manager.block_size - ) - budget_remaining = self.max_num_batched_tokens - num_batched_tokens - if self.enable_chunked_prefill: - chunk = min(num_new_tokens, budget_remaining) + if offload_resume: + # Blocks already held from the pre-park allocate; only re-check + # the batch budget. No re-match / re-allocate / re-park. + num_new_tokens = seq.num_prompt_tokens - seq.num_cached_tokens + budget_remaining = self.max_num_batched_tokens - num_batched_tokens + if self.enable_chunked_prefill: + chunk = min(num_new_tokens, budget_remaining) + else: + if num_new_tokens > budget_remaining and num_batched_tokens > 0: + self.waiting.appendleft(seq) + break + chunk = num_new_tokens else: - if num_new_tokens > budget_remaining and num_batched_tokens > 0: + # Probe cache hits FIRST so budget check sees the real + # (post-prefix-cache) remaining token count. + num_cached_blocks = self.block_manager.can_allocate(seq) + if num_cached_blocks < 0: self.waiting.appendleft(seq) break - chunk = num_new_tokens - assert chunk > 0, ( - f"chunk must be positive: {chunk=}, " - f"{num_new_tokens=}, {budget_remaining=}" - ) - self.block_manager.allocate(seq, num_cached_blocks) + num_new_tokens = ( + seq.num_prompt_tokens + - num_cached_blocks * self.block_manager.block_size + ) + budget_remaining = self.max_num_batched_tokens - num_batched_tokens + if self.enable_chunked_prefill: + chunk = min(num_new_tokens, budget_remaining) + else: + if num_new_tokens > budget_remaining and num_batched_tokens > 0: + self.waiting.appendleft(seq) + break + chunk = num_new_tokens + + self.block_manager.allocate(seq, num_cached_blocks) - if self.kv_connector is not None: - self.kv_connector.update_state_after_alloc(seq) + if self.kv_connector is not None: + self.kv_connector.update_state_after_alloc(seq) - if need_to_remove_to_load_kv_async_queue: - skipped_waiting_requests.append(seq) - seq.status = SequenceStatus.WAITING_FOR_REMOTE_KVS - continue + if need_to_remove_to_load_kv_async_queue: + skipped_waiting_requests.append(seq) + seq.status = SequenceStatus.WAITING_FOR_REMOTE_KVS + continue + assert chunk > 0, ( + f"chunk must be positive: {chunk=}, " + f"{num_new_tokens=}, {budget_remaining=}" + ) if self.cache_stats: self.cache_stats.update(seq.num_cached_tokens, seq.num_tokens) num_batched_tokens += chunk @@ -1211,6 +1255,14 @@ def postprocess( self.running.remove(seq) return finished_seqs + def _is_offload_connector(self) -> bool: + """True when the active KV connector is the CPU/NVMe offload backend. + + Offload wakes a parked seq into a SUFFIX prefill (not the P/D decode + jump). Connectors set ``is_offload = True`` to opt into this path. + """ + return getattr(self.kv_connector, "is_offload", False) + def _update_waiting_for_remote_kv(self, seq: Sequence) -> bool: """Check whether a remote KV transfer for *seq* has completed. From 291b43f1e7f37cbfd466c87f36d6e07a498e1520 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Sat, 30 May 2026 22:44:17 -0500 Subject: [PATCH 02/27] Optimize LMCache offload reload path --- atom/kv_transfer/disaggregation/aggregator.py | 42 +- atom/kv_transfer/disaggregation/base.py | 9 +- atom/kv_transfer/disaggregation/types.py | 21 +- atom/kv_transfer/offload/connector.py | 444 ++++++++++++++---- atom/kv_transfer/offload/gpu_connector.py | 214 ++++++++- atom/kv_transfer/offload/metadata.py | 4 +- atom/kv_transfer/offload/native_stitch.cpp | 81 ++++ atom/kv_transfer/offload/native_stitch.py | 44 ++ atom/model_engine/model_runner.py | 9 +- atom/model_engine/scheduler.py | 97 +++- tests/test_lmcache_offload_connector.py | 271 +++++++++++ 11 files changed, 1114 insertions(+), 122 deletions(-) create mode 100644 atom/kv_transfer/offload/native_stitch.cpp create mode 100644 atom/kv_transfer/offload/native_stitch.py create mode 100644 tests/test_lmcache_offload_connector.py diff --git a/atom/kv_transfer/disaggregation/aggregator.py b/atom/kv_transfer/disaggregation/aggregator.py index bbe74f0f0e..ff11059b8d 100644 --- a/atom/kv_transfer/disaggregation/aggregator.py +++ b/atom/kv_transfer/disaggregation/aggregator.py @@ -18,7 +18,7 @@ import logging -from atom.kv_transfer.disaggregation.types import KVConnectorOutput +from atom.kv_transfer.disaggregation.types import KVConnectorOutput, ReqId logger = logging.getLogger("atom") @@ -48,8 +48,10 @@ def __init__(self, world_size: int = 8) -> None: if world_size <= 0: raise ValueError(f"world_size must be positive, got {world_size}") self._world_size = world_size - self._seen_sending: dict[str, set[int]] = {} - self._seen_recving: dict[str, set[int]] = {} + self._seen_sending: dict[ReqId, set[int]] = {} + self._seen_recving: dict[ReqId, set[int]] = {} + self._seen_recv_failed: dict[ReqId, set[int]] = {} + self._seen_saving: dict[ReqId, set[int]] = {} @property def world_size(self) -> int: @@ -76,15 +78,33 @@ def aggregate(self, worker_outputs: list[KVConnectorOutput]) -> KVConnectorOutpu if wo.finished_recving: for rid in wo.finished_recving: self._seen_recving.setdefault(rid, set()).add(worker_idx) + if wo.failed_recving: + for rid in wo.failed_recving: + self._seen_recv_failed.setdefault(rid, set()).add(worker_idx) + if wo.finished_saving: + for rid in wo.finished_saving: + self._seen_saving.setdefault(rid, set()).add(worker_idx) done_sending = { rid for rid, workers in self._seen_sending.items() if len(workers) >= self._world_size } + failed_recving = set() + recv_ids = set(self._seen_recving) | set(self._seen_recv_failed) + for rid in recv_ids: + done_workers = self._seen_recving.get(rid, set()) + failed_workers = self._seen_recv_failed.get(rid, set()) + if failed_workers and len(done_workers | failed_workers) >= self._world_size: + failed_recving.add(rid) done_recving = { rid for rid, workers in self._seen_recving.items() + if len(workers) >= self._world_size and rid not in failed_recving + } + done_saving = { + rid + for rid, workers in self._seen_saving.items() if len(workers) >= self._world_size } @@ -92,18 +112,32 @@ def aggregate(self, worker_outputs: list[KVConnectorOutput]) -> KVConnectorOutpu del self._seen_sending[rid] for rid in done_recving: del self._seen_recving[rid] + self._seen_recv_failed.pop(rid, None) + for rid in failed_recving: + self._seen_recving.pop(rid, None) + self._seen_recv_failed.pop(rid, None) + for rid in done_saving: + del self._seen_saving[rid] return KVConnectorOutput( finished_sending=done_sending, finished_recving=done_recving, + failed_recving=failed_recving, + finished_saving=done_saving, ) def reset(self) -> None: """Clear all internal tracking state.""" self._seen_sending.clear() self._seen_recving.clear() + self._seen_recv_failed.clear() + self._seen_saving.clear() @property def pending_count(self) -> tuple[int, int]: """Return ``(num_pending_sending, num_pending_recving)``.""" - return len(self._seen_sending), len(self._seen_recving) + return ( + len(self._seen_sending), + len(set(self._seen_recving) | set(self._seen_recv_failed)) + + len(self._seen_saving), + ) diff --git a/atom/kv_transfer/disaggregation/base.py b/atom/kv_transfer/disaggregation/base.py index ca5c306ad5..0ee3a241ad 100644 --- a/atom/kv_transfer/disaggregation/base.py +++ b/atom/kv_transfer/disaggregation/base.py @@ -21,7 +21,7 @@ from abc import ABC, abstractmethod from typing import Any -from atom.kv_transfer.disaggregation.types import ConnectorMetadata +from atom.kv_transfer.disaggregation.types import ConnectorMetadata, KVConnectorOutput class KVConnectorBase(ABC): @@ -48,8 +48,11 @@ def start_load_kv(self, metadata: ConnectorMetadata) -> None: ... @abstractmethod - def get_finished(self) -> tuple[set, set]: - """Return ``(done_sending, done_recving)`` request ID sets. + def get_finished(self) -> tuple[set, set] | KVConnectorOutput: + """Return transfer completion status. + + Older connectors may return ``(done_sending, done_recving)``. Connectors + that need richer semantics can return :class:`KVConnectorOutput`. Called by the worker each engine step to report transfer status. """ diff --git a/atom/kv_transfer/disaggregation/types.py b/atom/kv_transfer/disaggregation/types.py index 46179aaf55..a61a3a72a7 100644 --- a/atom/kv_transfer/disaggregation/types.py +++ b/atom/kv_transfer/disaggregation/types.py @@ -19,7 +19,7 @@ # --------------------------------------------------------------------------- EngineId = str -ReqId = str +ReqId = str | int TransferId = int # --------------------------------------------------------------------------- @@ -59,22 +59,33 @@ class KVConnectorOutput: Attributes: finished_sending: Request IDs whose KV send completed on this worker. finished_recving: Request IDs whose KV receive completed on this worker. + failed_recving: Request IDs whose KV receive failed on this worker. + finished_saving: Request IDs whose local fire-and-forget save completed. expected_finished_count: How many finished notifications should be expected per request (used by the aggregator). """ - finished_sending: set[str] = field(default_factory=set) - finished_recving: set[str] = field(default_factory=set) + finished_sending: set[ReqId] = field(default_factory=set) + finished_recving: set[ReqId] = field(default_factory=set) + failed_recving: set[ReqId] = field(default_factory=set) + finished_saving: set[ReqId] = field(default_factory=set) expected_finished_count: int = 0 def is_empty(self) -> bool: """Return True if no transfers finished on this worker.""" - return not self.finished_sending and not self.finished_recving + return ( + not self.finished_sending + and not self.finished_recving + and not self.failed_recving + and not self.finished_saving + ) def __repr__(self) -> str: return ( f"KVConnectorOutput(sending={self.finished_sending}, " - f"recving={self.finished_recving})" + f"recving={self.finished_recving}, " + f"failed_recving={self.failed_recving}, " + f"finished_saving={self.finished_saving})" ) diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index 2cc34ca7cf..3415cabc93 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -7,9 +7,11 @@ * **Reuse real LMCache as a storage tier only** — per-rank ``LMCacheEngine`` for its ``StorageManager`` (CPU LRU + NVMe L3) + ``ChunkedTokenDatabase`` (chunk-256 keys). - We bypass ``engine.store/retrieve`` (token-major GPU path can't represent AITER's - swizzle) and instead move **opaque per-block bytes** via :class:`ATOMKVByteCodec` - into pinned ``KV_2LTD``-as-uint8 ``MemoryObj``s. + We bypass ``engine.store/retrieve`` (its token-major GPU path can't represent ATOM's + x-packed KV storage layout — ``K=(nb,H,D//x,bs,x)``, see ``ATOMKVByteCodec`` docstring; + loosely "swizzle", but a persistent storage layout, not LDS bank-swizzle) and instead + move **opaque per-block bytes** via :class:`ATOMKVByteCodec` into pinned + ``KV_2LTD``-as-uint8 ``MemoryObj``s. * **Daemon-after-forward copies** — ``start_load_kv`` only ``submit``s to a single serial copy daemon (ThreadPoolExecutor max_workers=1) and returns immediately, so the worker RPC thread is free for ``forward``; completions are polled in @@ -24,6 +26,7 @@ import logging import os import threading +import time from concurrent.futures import ThreadPoolExecutor import torch @@ -32,6 +35,7 @@ KVConnectorBase, KVConnectorSchedulerBase, ) +from atom.kv_transfer.disaggregation.types import KVConnectorOutput, ReqId from atom.kv_transfer.offload import config as offcfg from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec from atom.kv_transfer.offload.metadata import ( @@ -103,8 +107,10 @@ def __init__(self, config) -> None: ) self._tls = threading.local() # per-thread copy stream self._lock = threading.Lock() - self._done_load: set[str] = set() - self._done_save: set[str] = set() + self._done_load: set[ReqId] = set() + self._done_save: set[ReqId] = set() + self._failed_load: set[ReqId] = set() + self._load_active = threading.Event() self._engine = None self._sm = None @@ -142,9 +148,9 @@ def register_kv_caches(self, kv_caches: dict, transfer_tensors=None) -> None: def _logged_lookup(*a, **k): r = _orig_lookup(*a, **k) h = k.get("hashes") - logger.info("[ENGINE.LOOKUP] rank=%s lookup_id=%s nhashes=%s first3=%s -> %s", - _rk, k.get("lookup_id"), (len(h) if h is not None else None), - (list(h[:3]) if h else None), r) + logger.debug("[ENGINE.LOOKUP] rank=%s lookup_id=%s nhashes=%s first3=%s -> %s", + _rk, k.get("lookup_id"), (len(h) if h is not None else None), + (list(h[:3]) if h else None), r) return r self._engine.lookup = _logged_lookup @@ -158,8 +164,9 @@ def _logged_lookup(*a, **k): logger.warning("LMCache offload: lookup server not started: %s", e) logger.info( - "LMCache offload worker rank=%d: bytes_per_block=%d chunk=%d save=%s load=%s", - rank, self._codec.bytes_per_block, self.chunk_size, + "LMCache offload worker rank=%d: bytes_per_block=%d chunk=%d " + "codec_layout=%s save=%s load=%s", + rank, self._codec.bytes_per_block, self.chunk_size, self._codec.layout, self._do_save, self._do_load, ) @@ -169,21 +176,41 @@ def start_load_kv(self, metadata) -> None: return for req in metadata.requests: if req.load_spec is not None and self._do_load: - self._load_executor.submit(self._guard, self._do_load_req, req) + self._load_executor.submit(self._guard, "load", self._do_load_req, req) if req.save_spec is not None and self._do_save: - self._save_executor.submit(self._guard, self._do_save_req, req) - - def _guard(self, fn, req) -> None: + self._save_executor.submit(self._guard, "save", self._do_save_req, req) + + def _guard(self, kind: str, fn, req) -> None: + load_active = getattr(self, "_load_active", None) + if kind == "load" and load_active is None: + load_active = threading.Event() + self._load_active = load_active + if kind == "load": + load_active.set() try: fn(req) except Exception: logger.exception("LMCache offload: %s failed for %s", fn.__name__, req.req_id) - # Wake the seq anyway so it is not stuck parked; scheduler re-derives - # how much is actually cached (load) / proceeds (save). + if kind == "load": + self._lookup_unpin(req.req_id) with self._lock: - (self._done_load if fn is self._do_load_req else self._done_save).add( - req.req_id - ) + if kind == "load": + self._failed_load.add(req.req_id) + else: + # A failed save should not keep blocks pinned forever. The + # request simply loses this offload opportunity. + self._done_save.add(req.req_id) + finally: + if kind == "load": + load_active.clear() + + def _lookup_unpin(self, req_id) -> None: + if getattr(self, "_engine", None) is None: + return + try: + self._engine.lookup_unpin([str(req_id)]) # LMCache pin keyed by str id + except Exception: + pass def _stream(self) -> torch.cuda.Stream: """A CUDA stream owned by the calling copy-daemon thread (lazily made).""" @@ -193,49 +220,194 @@ def _stream(self) -> torch.cuda.Stream: self._tls.stream = s return s + def _host_tmp(self, nbytes: int) -> torch.Tensor: + """Pinned CPU scratch buffer owned by the calling copy-daemon thread.""" + buf = getattr(self._tls, "host_tmp", None) + if buf is None or int(buf.numel()) < int(nbytes): + try: + buf = torch.empty((int(nbytes),), dtype=torch.uint8, pin_memory=True) + except RuntimeError: + logger.warning( + "LMCache offload: pinned host scratch allocation failed; " + "falling back to pageable CPU memory", + exc_info=True, + ) + buf = torch.empty((int(nbytes),), dtype=torch.uint8) + self._tls.host_tmp = buf + return buf[: int(nbytes)] + + def _pause_save_for_load(self, stream: torch.cuda.Stream) -> None: + """Let critical-path loads drain before fire-and-forget save copies.""" + load_active = getattr(self, "_load_active", None) + if load_active is None or not load_active.is_set(): + return + stream.synchronize() + while load_active.is_set(): + time.sleep(0.001) + def _block_ids(self, req: LMCacheReqMeta, start: int, end: int) -> list[int]: return req.block_ids[start // self.block_size : _cdiv(end, self.block_size)] + def _profile_enabled(self) -> bool: + return os.environ.get("OFFLOAD_PROFILE", "1").lower() not in ( + "0", + "false", + "no", + "off", + ) + # -- copy daemon thread ---------------------------------------------- def _do_load_req(self, req: LMCacheReqMeta) -> None: ls = req.load_spec assert ls is not None stream = self._stream() - hbm = (ls.hbm_cached_tokens // self.chunk_size) * self.chunk_size + hbm = int(ls.hbm_cached_tokens) toks = req.token_ids[: ls.lmcache_cached_tokens] + t_total0 = time.perf_counter() + if int(ls.lmcache_cached_tokens) <= hbm: + self._lookup_unpin(req.req_id) + with self._lock: + self._done_load.add(req.req_id) + return mask = torch.ones(len(toks), dtype=torch.bool) mask[:hbm] = False + t0 = time.perf_counter() chunks = list(self._tdb.process_tokens(torch.tensor(toks), mask=mask)) + process_ms = (time.perf_counter() - t0) * 1000 logger.debug("offload _do_load req=%s hbm=%d lmc=%d chunks=%d", req.req_id, hbm, ls.lmcache_cached_tokens, len(chunks)) - # All-or-nothing: a partial load would let attention read uninitialized - # blocks. If any chunk is gone (evicted between lookup and load), skip the - # whole load — the seq wakes and re-prefills the suffix (loaded 0). + # All-or-nothing above the HBM prefix: a partial load would let attention + # read uninitialized blocks, and a chunk that overlaps an HBM-cache hit + # could overwrite shared prefix-cache blocks. In either case the seq + # wakes and re-prefills from its HBM floor. + if not chunks: + logger.warning("LMCache offload: no loadable chunks req=%s; re-prefill", + req.req_id) + self._lookup_unpin(req.req_id) + with self._lock: + self._failed_load.add(req.req_id) + return + for (s, _e, _key) in chunks: + if s < hbm: + logger.warning( + "LMCache offload: chunk overlaps HBM prefix req=%s hbm=%d " + "chunk_start=%d; re-prefill", + req.req_id, hbm, s, + ) + self._lookup_unpin(req.req_id) + with self._lock: + self._failed_load.add(req.req_id) + return + contains_ms = 0.0 for (_s, _e, key) in chunks: + t0 = time.perf_counter() if not self._sm.contains(key): + contains_ms += (time.perf_counter() - t0) * 1000 logger.warning("LMCache offload: load miss req=%s; re-prefill", req.req_id) + self._lookup_unpin(req.req_id) with self._lock: - self._done_load.add(req.req_id) - return - - for (s, e, key) in chunks: - mo = self._sm.get(key) - if mo is None: - with self._lock: - self._done_load.add(req.req_id) + self._failed_load.add(req.req_id) return - self._codec.host_to_gpu(mo.tensor, self._block_ids(req, s, e), stream) + contains_ms += (time.perf_counter() - t0) * 1000 + + loaded_objs = [] + get_ms = 0.0 + host_alloc_ms = 0.0 + stitch_ms = 0.0 + h2d_submit_ms = 0.0 + sync_ms = 0.0 + nblocks = 0 + nbytes = 0 + copy_calls = 0 + chunk_bids: list[list[int]] = [] + try: + for (s, e, key) in chunks: + t0 = time.perf_counter() + mo = self._sm.get(key) + get_ms += (time.perf_counter() - t0) * 1000 + if mo is None: + t0 = time.perf_counter() + stream.synchronize() + sync_ms += (time.perf_counter() - t0) * 1000 + for loaded_mo in loaded_objs: + loaded_mo.ref_count_down() + self._lookup_unpin(req.req_id) + with self._lock: + self._failed_load.add(req.req_id) + return + loaded_objs.append(mo) + bids = self._block_ids(req, s, e) + chunk_bids.append(bids) + nblocks += len(bids) + nbytes += len(bids) * self._codec.bytes_per_block + if self._codec.layout != "segment_indexed": + copy_calls += self._codec.copy_calls_for_block_ids(bids) + t0 = time.perf_counter() + self._codec.host_to_gpu(mo.tensor, bids, stream) + h2d_submit_ms += (time.perf_counter() - t0) * 1000 + if self._codec.layout == "segment_indexed": + all_bids = [bid for bids in chunk_bids for bid in bids] + copy_calls = self._codec.copy_calls_for_block_ids(all_bids) + t0 = time.perf_counter() + req_buf = self._host_tmp(nbytes) + host_alloc_ms += (time.perf_counter() - t0) * 1000 + t0 = time.perf_counter() + self._codec.stitch_chunk_buffers( + req_buf, + [mo.tensor for mo in loaded_objs], + [len(bids) for bids in chunk_bids], + ) + stitch_ms += (time.perf_counter() - t0) * 1000 + t0 = time.perf_counter() + self._codec.host_to_gpu(req_buf, all_bids, stream) + h2d_submit_ms += (time.perf_counter() - t0) * 1000 + t0 = time.perf_counter() + stream.synchronize() + sync_ms += (time.perf_counter() - t0) * 1000 + except Exception: + try: + t0 = time.perf_counter() + stream.synchronize() + sync_ms += (time.perf_counter() - t0) * 1000 + finally: + for loaded_mo in loaded_objs: + loaded_mo.ref_count_down() + self._lookup_unpin(req.req_id) + raise + for mo in loaded_objs: mo.ref_count_down() - stream.synchronize() # Release the lookup pin (taken by the scheduler's LookupClient.lookup) # now that the chunks are safely in GPU; lets the pool evict them later. - try: - self._engine.lookup_unpin([str(req.req_id)]) # LMCache pin keyed by str id - except Exception: - pass + self._lookup_unpin(req.req_id) with self._lock: self._done_load.add(req.req_id) + if self._profile_enabled(): + total_ms = (time.perf_counter() - t_total0) * 1000 + logger.info( + "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " + "chunks=%d blocks=%d bytes=%.3fGiB copy_calls=%d " + "layout=%s process_ms=%.2f contains_ms=%.2f get_ms=%.2f " + "host_alloc_ms=%.2f stitch_ms=%.2f h2d_submit_ms=%.2f " + "sync_ms=%.2f total_ms=%.2f", + getattr(self, "_rank", "?"), + req.req_id, + hbm, + ls.lmcache_cached_tokens, + len(chunks), + nblocks, + nbytes / 1024**3, + copy_calls, + self._codec.layout, + process_ms, + contains_ms, + get_ms, + host_alloc_ms, + stitch_ms, + h2d_submit_ms, + sync_ms, + total_ms, + ) logger.info("offload _do_load DONE req=%s", req.req_id) def _do_save_req(self, req: LMCacheReqMeta) -> None: @@ -253,48 +425,122 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: self._done_save.add(req.req_id) return + t_total0 = time.perf_counter() mask = torch.ones(len(toks), dtype=torch.bool) mask[:skip] = False + t0 = time.perf_counter() chunks = list(self._tdb.process_tokens(torch.tensor(toks), mask=mask)) + process_ms = (time.perf_counter() - t0) * 1000 keys, objs, already = [], [], 0 - for (s, e, key) in chunks: - if self._sm.contains(key): # already offloaded → skip wasted D2H - already += 1 - continue - bids = self._block_ids(req, s, e) - nbytes = len(bids) * self._codec.bytes_per_block - mo = self._sm.allocate(torch.Size((nbytes,)), torch.uint8, - fmt=MemoryFormat.KV_2LTD) - if mo is None: # pool under pressure; stop here - break - # D2H on this thread's dedicated copy stream (off the compute stream). - self._codec.gpu_to_host(mo.tensor, bids, stream) - keys.append(key) - objs.append(mo) - - if keys: - stream.synchronize() # stream-specific - self._sm.batched_put(keys, objs) + put_started = False + contains_ms = 0.0 + alloc_ms = 0.0 + d2h_submit_ms = 0.0 + sync_ms = 0.0 + put_ms = 0.0 + nblocks = 0 + total_nbytes = 0 + copy_calls = 0 + try: + for (s, e, key) in chunks: + self._pause_save_for_load(stream) + t0 = time.perf_counter() + if self._sm.contains(key): # already offloaded → skip wasted D2H + contains_ms += (time.perf_counter() - t0) * 1000 + already += 1 + continue + contains_ms += (time.perf_counter() - t0) * 1000 + bids = self._block_ids(req, s, e) + chunk_nbytes = len(bids) * self._codec.bytes_per_block + t0 = time.perf_counter() + mo = self._sm.allocate(torch.Size((chunk_nbytes,)), torch.uint8, + fmt=MemoryFormat.KV_2LTD) + alloc_ms += (time.perf_counter() - t0) * 1000 + if mo is None: # pool under pressure; stop here + break + keys.append(key) + objs.append(mo) + nblocks += len(bids) + total_nbytes += chunk_nbytes + copy_calls += self._codec.copy_calls_for_block_ids(bids) + # D2H on this thread's dedicated copy stream (off the compute stream). + t0 = time.perf_counter() + self._codec.gpu_to_host(mo.tensor, bids, stream) + d2h_submit_ms += (time.perf_counter() - t0) * 1000 + + if keys: + t0 = time.perf_counter() + stream.synchronize() # stream-specific + sync_ms += (time.perf_counter() - t0) * 1000 + put_started = True + t0 = time.perf_counter() + self._sm.batched_put(keys, objs) + put_ms += (time.perf_counter() - t0) * 1000 + except Exception: + if not put_started: + try: + t0 = time.perf_counter() + stream.synchronize() + sync_ms += (time.perf_counter() - t0) * 1000 + finally: + for mo in objs: + mo.ref_count_down() + raise with self._lock: self._done_save.add(req.req_id) - _kh = [getattr(k, "chunk_hash", None) for k in keys[:2]] - _contains = [bool(self._sm.contains(k)) for k in keys[:2]] - logger.info("[OFFLOAD-SAVE] rank=%s req=%s toks=%d chunks=%d stored=%d already=%d " - "chunkhash2=%s contains=%s", - self._rank, req.req_id, len(toks), len(chunks), len(keys), - already, _kh, _contains) + if self._profile_enabled(): + total_ms = (time.perf_counter() - t_total0) * 1000 + logger.info( + "[OFFLOAD-SAVE-PROF] rank=%s req=%s toks=%d chunks=%d " + "stored=%d already=%d blocks=%d bytes=%.3fGiB copy_calls=%d " + "layout=%s process_ms=%.2f contains_ms=%.2f alloc_ms=%.2f " + "d2h_submit_ms=%.2f sync_ms=%.2f put_ms=%.2f total_ms=%.2f", + getattr(self, "_rank", "?"), + req.req_id, + len(toks), + len(chunks), + len(keys), + already, + nblocks, + total_nbytes / 1024**3, + copy_calls, + self._codec.layout, + process_ms, + contains_ms, + alloc_ms, + d2h_submit_ms, + sync_ms, + put_ms, + total_ms, + ) + if logger.isEnabledFor(logging.DEBUG): + _kh = [getattr(k, "chunk_hash", None) for k in keys[:2]] + _contains = [bool(self._sm.contains(k)) for k in keys[:2]] + logger.debug("[OFFLOAD-SAVE] rank=%s req=%s toks=%d chunks=%d stored=%d already=%d " + "chunkhash2=%s contains=%s", + self._rank, req.req_id, len(toks), len(chunks), len(keys), + already, _kh, _contains) # -- per-step (RPC thread, post-forward): poll completions ------------ - def get_finished(self) -> tuple[set, set]: - # (finished_sending, finished_recving). Offload SAVES are fire-and-forget - # (they don't free blocks), so finished_sending is ALWAYS empty; only - # completed LOADS are reported, to wake parked seqs for suffix prefill. + def get_finished(self) -> KVConnectorOutput: + # Offload uses extended completion states: + # - finished_recving wakes successfully loaded requests. + # - failed_recving wakes them for recompute using already allocated blocks. + # - finished_saving releases blocks whose free was deferred during save. with self._lock: dl = set(self._done_load) + fl = set(self._failed_load) + ds = set(self._done_save) self._done_save.clear() self._done_load.clear() - return set(), dl + self._failed_load.clear() + return KVConnectorOutput( + finished_sending=set(), + finished_recving=dl, + failed_recving=fl, + finished_saving=ds, + ) def get_finished_recv_blocks(self) -> list[int]: # Local CUDA copies are ordered by the copy stream + synchronize() before @@ -329,6 +575,7 @@ def __init__(self, config) -> None: # prefix is stored to LMCache once prefill computes it # (seq.prefix_hashes_published flips True), chunk by chunk. self._save_tracker: dict[str, list] = {} + self._save_inflight: set[str] = set() self._lookup_in_step: list[str] = [] try: @@ -353,27 +600,35 @@ def get_num_new_matched_tokens(self, seq) -> tuple[int, bool]: except Exception: logger.exception("LMCache offload lookup failed for seq %s", seq.id) return 0, False - _lh = None - try: - tdb = getattr(self._lookup_client, "token_database", None) - if tdb is not None: - _lh = [k for (_s, _e, k) in list( - tdb.process_tokens(token_ids, make_key=False))[:3]] - except Exception as e: - _lh = f"err:{e}" - logger.info("[OFFLOAD-LOOKUP] seq=%s num_prompt=%d hbm_cached=%d hit=%s lookuphash3=%s", - seq.id, num_prompt, int(seq.num_cached_tokens), hit, _lh) + if logger.isEnabledFor(logging.DEBUG): + _lh = None + try: + tdb = getattr(self._lookup_client, "token_database", None) + if tdb is not None: + _lh = [k for (_s, _e, k) in list( + tdb.process_tokens(token_ids, make_key=False))[:3]] + except Exception as e: + _lh = f"err:{e}" + logger.debug("[OFFLOAD-LOOKUP] seq=%s num_prompt=%d hbm_cached=%d hit=%s lookuphash3=%s", + seq.id, num_prompt, int(seq.num_cached_tokens), hit, _lh) if not hit: return 0, False - self._lookup_in_step.append(str(seq.id)) - need = int(hit) - int(seq.num_cached_tokens) - if int(hit) == num_prompt: # full-prompt hit → recompute last token - need -= 1 + sid = str(seq.id) + hit = int(hit) + if hit == num_prompt: # full-prompt hit → recompute last token + hit -= 1 + need = hit - int(seq.num_cached_tokens) if need <= 0: + if self._lookup_client is not None: + try: + self._lookup_client.clear_lookup_status(sid) + except Exception: + pass return 0, False - self._load_specs[str(seq.id)] = LoadSpec( + self._lookup_in_step.append(sid) + self._load_specs[sid] = LoadSpec( hbm_cached_tokens=int(seq.num_cached_tokens), - lmcache_cached_tokens=int(hit), + lmcache_cached_tokens=hit, can_load=False, ) return need, True # True => park in WAITING_FOR_REMOTE_KVS @@ -381,8 +636,8 @@ def get_num_new_matched_tokens(self, seq) -> tuple[int, bool]: def update_state_after_alloc(self, seq) -> None: sid = str(seq.id) ls = self._load_specs.get(sid) - logger.info("[OFFLOAD-ALLOC] seq=%s ls_found=%s num_cached_now=%s", - seq.id, ls is not None, int(getattr(seq, "num_cached_tokens", -1))) + logger.debug("[OFFLOAD-ALLOC] seq=%s ls_found=%s num_cached_now=%s", + seq.id, ls is not None, int(getattr(seq, "num_cached_tokens", -1))) if ls is not None: ls.can_load = True self._reqs_need_recv[sid] = seq @@ -397,12 +652,12 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: self._lookup_in_step = [] # Loads - logger.info("[OFFLOAD-BUILD] reqs_need_recv=%d", len(self._reqs_need_recv)) + logger.debug("[OFFLOAD-BUILD] reqs_need_recv=%d", len(self._reqs_need_recv)) for sid, seq in self._reqs_need_recv.items(): ls = self._load_specs.pop(sid, None) if ls is None or not ls.can_load: - logger.info("[OFFLOAD-LOAD-SKIP] seq=%s ls=%s can_load=%s", - sid, ls is not None, getattr(ls, "can_load", None)) + logger.debug("[OFFLOAD-LOAD-SKIP] seq=%s ls=%s can_load=%s", + sid, ls is not None, getattr(ls, "can_load", None)) continue # ★ Use the REAL HBM-cached count as the load floor. # get_num_new_matched_tokens runs BEFORE the prefix-cache match in @@ -442,8 +697,8 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: aligned = (int(seq.num_prompt_tokens) // chunk) * chunk if aligned <= saved: continue - logger.info("[OFFLOAD-SAVE-EMIT] seq=%s num_prompt=%d aligned=%d saved=%d", - seq.id, int(seq.num_prompt_tokens), aligned, saved) + logger.debug("[OFFLOAD-SAVE-EMIT] seq=%s num_prompt=%d aligned=%d saved=%d", + seq.id, int(seq.num_prompt_tokens), aligned, saved) meta.add_request(LMCacheReqMeta( req_id=seq.id, token_ids=list(seq.token_ids[:aligned]), @@ -451,9 +706,20 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: save_spec=SaveSpec(skip_leading_tokens=saved, can_save=True), )) entry[1] = aligned + self._save_inflight.add(sid) self._reqs_need_recv.clear() return meta + def should_defer_free(self, seq) -> bool: + return str(seq.id) in self._save_inflight + + def save_finished(self, req_id) -> None: + self._save_inflight.discard(str(req_id)) + + def load_failed(self, req_id) -> None: + self._load_specs.pop(str(req_id), None) + self._reqs_need_recv.pop(str(req_id), None) + def request_finished(self, seq) -> None: sid = str(seq.id) self._load_specs.pop(sid, None) diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py index 01e6c68af2..321996a353 100644 --- a/atom/kv_transfer/offload/gpu_connector.py +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -6,8 +6,13 @@ Why a byte codec instead of an LMCache ``GPUConnectorInterface`` subclass: LMCache's ``engine.store/retrieve`` GPU path only emits token-major formats -(``KV_2LTD`` etc.) via ``normalize_kv_and_discover_format``, which rejects -AITER's swizzled K layout ``(nb, H, D//x, bs, x)`` and strided V ``(nb, H, D, bs)``. +(``KV_2LTD`` etc.) via ``normalize_kv_and_discover_format``, which only accepts the +clean NHD/HND family and rejects ATOM's **x-packed, head-major** K layout +``(nb, H, D//x, bs, x)`` and strided V ``(nb, H, D, bs)`` (``x = 16 // elem``; verified +``atom/model_ops/attentions/aiter_attention.py:488-502``). NB: this is a *persistent +HBM storage layout*, NOT the transient LDS bank-conflict "swizzle"; we call it "swizzle" +only as loose shorthand. It is also specific to this ATOM aiter path — stock vLLM's aiter +FA backend (``rocm_aiter_fa``) uses the clean token-major ``(2,nb,bs,H,D)`` LMCache handles. We therefore bypass that path: we store **opaque per-block bytes** (byte-identical round-trip — the attention kernel reads back its own layout) and drive LMCache only as a storage tier (``StorageManager`` + ``ChunkedTokenDatabase``). @@ -23,8 +28,14 @@ from __future__ import annotations +import logging +import os +import threading + import torch +logger = logging.getLogger("atom") + class ATOMKVByteCodec: """Per-block byte mover between paged GPU KV tensors and a flat host buffer.""" @@ -61,6 +72,40 @@ def __init__(self, kv_caches: dict) -> None: acc += nb self.bytes_per_block: int = acc self.num_blocks: int = int(self._segments[0].shape[0]) + self.layout = os.environ.get("OFFLOAD_CODEC_LAYOUT", "block").lower() + if self.layout not in ("block", "segment", "segment_indexed"): + self.layout = "block" + self._tls = threading.local() + self._native_stitch = None + if ( + self.layout == "segment_indexed" + and os.environ.get("OFFLOAD_NATIVE_STITCH", "0").lower() + not in ("0", "false", "no", "off") + ): + try: + from atom.kv_transfer.offload import native_stitch + + native_stitch.load_extension() + self._native_stitch = native_stitch.stitch_chunk_buffers + except Exception: + logger.warning( + "ATOMKVByteCodec: native stitch unavailable; using torch stitch", + exc_info=True, + ) + + @property + def segments_per_block(self) -> int: + return len(self._segments) + + def copy_calls_for_blocks(self, nblocks: int) -> int: + return int(nblocks) * len(self._segments) + + def copy_calls_for_block_ids(self, block_ids: list[int]) -> int: + if self.layout == "block": + return self.copy_calls_for_blocks(len(block_ids)) + if self.layout == "segment_indexed": + return len(self._segments) * 2 + return len(self._segments) * len(list(self._contiguous_runs(block_ids))) # -- helpers ---------------------------------------------------------- @staticmethod @@ -73,6 +118,109 @@ def _block_bytes_view(seg: torch.Tensor, block_id: int) -> torch.Tensor: raise RuntimeError("ATOMKVByteCodec: block slice not contiguous") return blk.reshape(-1).view(torch.uint8) + @staticmethod + def _blocks_bytes_view( + seg: torch.Tensor, + block_id: int, + nblocks: int, + ) -> torch.Tensor: + """Flat ``uint8`` view of a contiguous block range (no copy).""" + blk = seg[block_id : block_id + nblocks] + if not blk.is_contiguous(): + raise RuntimeError("ATOMKVByteCodec: block range not contiguous") + return blk.reshape(-1).view(torch.uint8) + + @staticmethod + def _contiguous_runs(block_ids: list[int]): + """Yield ``(logical_start, physical_start, run_len)`` for increasing + physical block-id runs in logical order.""" + if not block_ids: + return + logical_start = 0 + physical_start = block_ids[0] + prev = block_ids[0] + run_len = 1 + for logical_idx, bid in enumerate(block_ids[1:], start=1): + if bid == prev + 1: + prev = bid + run_len += 1 + continue + yield logical_start, physical_start, run_len + logical_start = logical_idx + physical_start = bid + prev = bid + run_len = 1 + yield logical_start, physical_start, run_len + + def _segment_bases(self, nblocks: int) -> list[int]: + bases = [] + acc = 0 + for nb in self._seg_block_bytes: + bases.append(acc) + acc += nb * nblocks + return bases + + def stitch_chunk_buffers( + self, + dst: torch.Tensor, + chunk_buffers: list[torch.Tensor], + chunk_block_counts: list[int], + ) -> None: + """CPU-side stitch from per-LMCache-chunk segment-major buffers into one + request-level segment-major buffer. + + Each stored chunk is laid out as ``[seg0 chunk_blocks | seg1 ...]``. A + single request-level indexed H2D scatter expects + ``[seg0 all_blocks | seg1 all_blocks | ...]``. + """ + if self._native_stitch is not None: + self._native_stitch( + dst, + chunk_buffers, + chunk_block_counts, + self._seg_block_bytes, + ) + return + total_blocks = sum(chunk_block_counts) + dst_bases = self._segment_bases(total_blocks) + src_bases_by_chunk = [ + self._segment_bases(nblocks) for nblocks in chunk_block_counts + ] + for seg_idx, (dst_base, nb) in enumerate( + zip(dst_bases, self._seg_block_bytes) + ): + parts = [ + src[ + bases[seg_idx] : bases[seg_idx] + nblocks * nb + ] + for src, bases, nblocks in zip( + chunk_buffers, src_bases_by_chunk, chunk_block_counts + ) + ] + torch.cat( + parts, + out=dst[dst_base : dst_base + total_blocks * nb], + ) + + def _tmp_bytes(self, seg: torch.Tensor, nblocks: int) -> torch.Tensor: + elems = int(seg[0].numel()) * seg.element_size() + key = (str(seg.device), "uint8", elems, int(nblocks)) + cache = getattr(self._tls, "tmp", None) + if cache is None: + cache = {} + self._tls.tmp = cache + tmp = cache.get(key) + if tmp is None: + tmp = torch.empty((nblocks, elems), dtype=torch.uint8, device=seg.device) + cache[key] = tmp + return tmp + + @staticmethod + def _segment_bytes_matrix(seg: torch.Tensor) -> torch.Tensor: + if not seg.is_contiguous(): + raise RuntimeError("ATOMKVByteCodec: segment tensor not contiguous") + return seg.reshape(seg.shape[0], -1).view(torch.uint8) + # -- public API ------------------------------------------------------- def gpu_to_host( self, @@ -84,6 +232,36 @@ def gpu_to_host( pinned ``host_buf`` (uint8, length == len(block_ids) * bytes_per_block).""" ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() with ctx: + if self.layout == "segment_indexed": + idx = torch.tensor( + block_ids, dtype=torch.long, device=self._segments[0].device + ) + bases = self._segment_bases(len(block_ids)) + for seg, base, nb in zip( + self._segments, bases, self._seg_block_bytes + ): + mat = self._segment_bytes_matrix(seg) + tmp = self._tmp_bytes(seg, len(block_ids)) + torch.index_select(mat, 0, idx, out=tmp) + host_buf[base : base + len(block_ids) * nb].copy_( + tmp.reshape(-1), non_blocking=True + ) + return + + if self.layout == "segment": + bases = self._segment_bases(len(block_ids)) + runs = list(self._contiguous_runs(block_ids)) + for seg, base, nb in zip( + self._segments, bases, self._seg_block_bytes + ): + for logical_start, physical_start, run_len in runs: + src = self._blocks_bytes_view(seg, physical_start, run_len) + dst = base + logical_start * nb + host_buf[dst : dst + run_len * nb].copy_( + src, non_blocking=True + ) + return + for i, bid in enumerate(block_ids): base = i * self.bytes_per_block for seg, off, nb in zip( @@ -102,6 +280,38 @@ def host_to_gpu( cache at ``block_ids`` (in-place into the real KV tensors).""" ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() with ctx: + if self.layout == "segment_indexed": + idx = torch.tensor( + block_ids, dtype=torch.long, device=self._segments[0].device + ) + bases = self._segment_bases(len(block_ids)) + for seg, base, nb in zip( + self._segments, bases, self._seg_block_bytes + ): + mat = self._segment_bytes_matrix(seg) + tmp = self._tmp_bytes(seg, len(block_ids)) + tmp.copy_( + host_buf[base : base + len(block_ids) * nb].reshape_as(tmp), + non_blocking=True, + ) + mat.index_copy_(0, idx, tmp) + return + + if self.layout == "segment": + bases = self._segment_bases(len(block_ids)) + runs = list(self._contiguous_runs(block_ids)) + for seg, base, nb in zip( + self._segments, bases, self._seg_block_bytes + ): + for logical_start, physical_start, run_len in runs: + dst = self._blocks_bytes_view(seg, physical_start, run_len) + src = base + logical_start * nb + dst.copy_( + host_buf[src : src + run_len * nb], + non_blocking=True, + ) + return + for i, bid in enumerate(block_ids): base = i * self.bytes_per_block for seg, off, nb in zip( diff --git a/atom/kv_transfer/offload/metadata.py b/atom/kv_transfer/offload/metadata.py index f4275e571b..831dcc9658 100644 --- a/atom/kv_transfer/offload/metadata.py +++ b/atom/kv_transfer/offload/metadata.py @@ -15,7 +15,7 @@ from dataclasses import dataclass -from atom.kv_transfer.disaggregation.types import ConnectorMetadata +from atom.kv_transfer.disaggregation.types import ConnectorMetadata, ReqId @dataclass @@ -45,7 +45,7 @@ class SaveSpec: class LMCacheReqMeta: """Everything the worker needs to load/save one request's KV this step.""" - req_id: str + req_id: ReqId # Token ids covering the prefix being moved (used to derive chunk-256 keys via # LMCache's ChunkedTokenDatabase). For load: prompt[:lmcache_cached_tokens]; # for save: computed token ids. diff --git a/atom/kv_transfer/offload/native_stitch.cpp b/atom/kv_transfer/offload/native_stitch.cpp new file mode 100644 index 0000000000..b81022e354 --- /dev/null +++ b/atom/kv_transfer/offload/native_stitch.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include +#include +#include + +void stitch_chunk_buffers( + torch::Tensor dst, + std::vector chunk_buffers, + std::vector chunk_block_counts, + std::vector seg_block_bytes) { + TORCH_CHECK(dst.device().is_cpu(), "dst must be a CPU tensor"); + TORCH_CHECK(dst.dtype() == torch::kUInt8, "dst must be uint8"); + TORCH_CHECK(dst.is_contiguous(), "dst must be contiguous"); + TORCH_CHECK( + chunk_buffers.size() == chunk_block_counts.size(), + "chunk_buffers and chunk_block_counts size mismatch"); + + const int64_t nchunks = static_cast(chunk_buffers.size()); + const int64_t nsegs = static_cast(seg_block_bytes.size()); + int64_t total_blocks = 0; + for (int64_t nblocks : chunk_block_counts) { + TORCH_CHECK(nblocks >= 0, "chunk block count must be non-negative"); + total_blocks += nblocks; + } + + std::vector dst_bases(nsegs); + int64_t acc = 0; + for (int64_t seg = 0; seg < nsegs; ++seg) { + const int64_t nb = seg_block_bytes[seg]; + TORCH_CHECK(nb >= 0, "segment byte count must be non-negative"); + dst_bases[seg] = acc; + acc += nb * total_blocks; + } + TORCH_CHECK(dst.numel() >= acc, "dst is smaller than stitched output"); + + std::vector src_ptrs(nchunks); + std::vector src_offsets(nchunks * nsegs); + for (int64_t c = 0; c < nchunks; ++c) { + const auto& src = chunk_buffers[c]; + TORCH_CHECK(src.device().is_cpu(), "chunk buffer must be a CPU tensor"); + TORCH_CHECK(src.dtype() == torch::kUInt8, "chunk buffer must be uint8"); + TORCH_CHECK(src.is_contiguous(), "chunk buffer must be contiguous"); + src_ptrs[c] = src.data_ptr(); + + int64_t src_acc = 0; + const int64_t nblocks = chunk_block_counts[c]; + for (int64_t seg = 0; seg < nsegs; ++seg) { + src_offsets[c * nsegs + seg] = src_acc; + src_acc += seg_block_bytes[seg] * nblocks; + } + TORCH_CHECK(src.numel() >= src_acc, "chunk buffer is smaller than expected"); + } + + auto* dst_ptr = dst.data_ptr(); + at::parallel_for(0, nsegs, 1, [&](int64_t begin, int64_t end) { + for (int64_t seg = begin; seg < end; ++seg) { + const int64_t nb = seg_block_bytes[seg]; + int64_t logical_block_start = 0; + for (int64_t c = 0; c < nchunks; ++c) { + const int64_t nblocks = chunk_block_counts[c]; + const int64_t nbytes = nblocks * nb; + if (nbytes > 0) { + std::memcpy( + dst_ptr + dst_bases[seg] + logical_block_start * nb, + src_ptrs[c] + src_offsets[c * nsegs + seg], + static_cast(nbytes)); + } + logical_block_start += nblocks; + } + } + }); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("stitch_chunk_buffers", &stitch_chunk_buffers); +} diff --git a/atom/kv_transfer/offload/native_stitch.py b/atom/kv_transfer/offload/native_stitch.py new file mode 100644 index 0000000000..c9c2ac5e70 --- /dev/null +++ b/atom/kv_transfer/offload/native_stitch.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Optional native CPU stitch for LMCache chunk buffers. + +This is deliberately a small C++ CPU extension, not a HIP op: current profiles +show H2D at tens of milliseconds, while Python/Torch host repacking takes more +than a second for MiniMax-M2.5 32K. +""" + +from __future__ import annotations + +from pathlib import Path + +from torch.utils.cpp_extension import load + + +_EXT = None + + +def _load_ext(): + global _EXT + if _EXT is None: + src = Path(__file__).with_name("native_stitch.cpp") + _EXT = load( + name="atom_lmcache_native_stitch", + sources=[str(src)], + extra_cflags=["-O3"], + verbose=False, + ) + return _EXT + + +def load_extension() -> None: + _load_ext() + + +def stitch_chunk_buffers(dst, chunk_buffers, chunk_block_counts, seg_block_bytes) -> None: + _load_ext().stitch_chunk_buffers( + dst, + chunk_buffers, + [int(x) for x in chunk_block_counts], + [int(x) for x in seg_block_bytes], + ) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 15d214798c..ebd2823712 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -2044,8 +2044,13 @@ def async_proc_aggregation(self) -> KVConnectorOutput: """Collect finished send/recv status from the KV connector.""" connector = get_kvconnector() if connector is None: - return KVConnectorOutput(finished_sending=[], finished_recving=[]) - done_sending, done_recving = connector.get_finished() + return KVConnectorOutput() + + finished = connector.get_finished() + if isinstance(finished, KVConnectorOutput): + return finished + + done_sending, done_recving = finished return KVConnectorOutput( finished_sending=done_sending, finished_recving=done_recving diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 3bc6c3f07f..ac72b80953 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -428,6 +428,7 @@ def __init__(self, config: Config): # KV transfer bookkeeping self.finished_recving_kv_req_ids: list[int] = [] + self.failed_recving_kv_req_ids: list[int] = [] self.deferred_free_blocks: dict[int, Sequence] = {} # Scheduling delay for batching efficiency @@ -707,14 +708,24 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: # KV Transfer: skip request if still waiting for remote KVs waiting_remote_to_waiting_ready = False if seq.status == SequenceStatus.WAITING_FOR_REMOTE_KVS: - waiting_remote_to_waiting_ready = self._update_waiting_for_remote_kv( - seq - ) - if waiting_remote_to_waiting_ready: + if self._pop_req_id(self.failed_recving_kv_req_ids, seq.id): + if self.kv_connector is not None and hasattr( + self.kv_connector, "load_failed" + ): + self.kv_connector.load_failed(seq.id) seq.status = SequenceStatus.WAITING + seq.offload_loaded = False + seq.offload_loaded_tokens = seq.num_cached_tokens + seq.offload_load_failed = True else: - skipped_waiting_requests.append(seq) - continue + waiting_remote_to_waiting_ready = self._update_waiting_for_remote_kv( + seq + ) + if waiting_remote_to_waiting_ready: + seq.status = SequenceStatus.WAITING + else: + skipped_waiting_requests.append(seq) + continue # OFFLOAD fresh-wake: a seq whose CPU/NVMe prefix just finished # loading into its GPU blocks. Unlike P/D (which jumps straight to @@ -741,9 +752,11 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: offload_resume = ( is_offload - and getattr(seq, "offload_loaded", False) + and ( + getattr(seq, "offload_loaded", False) + or getattr(seq, "offload_load_failed", False) + ) and len(seq.block_table) > 0 - and seq.num_cached_tokens > 0 ) need_to_remove_to_load_kv_async_queue = False @@ -1243,7 +1256,16 @@ def postprocess( self._partial_prefill_count -= 1 if self.kv_connector is not None: if not self.kv_connector.is_producer: - self.block_manager.deallocate(seq) + if hasattr(self.kv_connector, "should_defer_free") and ( + self.kv_connector.should_defer_free(seq) + ): + logger.debug( + "Deferring block free for seq %s until KV save completes.", + seq.id, + ) + self.deferred_free_blocks[seq.id] = seq + else: + self.block_manager.deallocate(seq) else: logger.debug( "Deferring block free for seq %s until KV send completes.", @@ -1263,6 +1285,22 @@ def _is_offload_connector(self) -> bool: """ return getattr(self.kv_connector, "is_offload", False) + @staticmethod + def _pop_req_id(req_ids: list, seq_id) -> bool: + candidates = (seq_id, str(seq_id)) + for candidate in candidates: + if candidate in req_ids: + req_ids.remove(candidate) + return True + try: + int_id = int(seq_id) + except (TypeError, ValueError): + return False + if int_id in req_ids: + req_ids.remove(int_id) + return True + return False + def _update_waiting_for_remote_kv(self, seq: Sequence) -> bool: """Check whether a remote KV transfer for *seq* has completed. @@ -1271,10 +1309,9 @@ def _update_waiting_for_remote_kv(self, seq: Sequence) -> bool: scheduling step. When ready, the sequence transitions back from ``WAITING_FOR_REMOTE_KVS`` to ``WAITING``. """ - if seq.id not in self.finished_recving_kv_req_ids: + if not self._pop_req_id(self.finished_recving_kv_req_ids, seq.id): return False - self.finished_recving_kv_req_ids.remove(seq.id) logger.debug("KV transfer finished for seq %s, ready for scheduling.", seq.id) return True @@ -1288,6 +1325,15 @@ def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): if kv_connector_output is None: return + def _pop_deferred(req_id): + seq = self.deferred_free_blocks.pop(req_id, None) + if seq is not None: + return seq + try: + return self.deferred_free_blocks.pop(int(req_id), None) + except (TypeError, ValueError): + return None + for req_id in kv_connector_output.finished_recving or (): assert ( not self.kv_connector.is_producer @@ -1295,15 +1341,36 @@ def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): logger.debug("Finished recving KV transfer for request %s", req_id) self.finished_recving_kv_req_ids.append(req_id) + for req_id in kv_connector_output.failed_recving or (): + assert ( + not self.kv_connector.is_producer + ), "Only consumer should update failed KV recv status" + logger.warning("KV receive failed for request %s; falling back to prefill.", req_id) + self.failed_recving_kv_req_ids.append(req_id) + for req_id in kv_connector_output.finished_sending or (): assert ( self.kv_connector.is_producer ), "Only producer should free blocks after sending KV" logger.debug("Finished sending KV transfer for request %s", req_id) - assert ( - req_id in self.deferred_free_blocks - ), f"req_id={req_id} not found in deferred_free_blocks" - self.block_manager.deallocate(self.deferred_free_blocks.pop(req_id)) + seq = _pop_deferred(req_id) + assert seq is not None, f"req_id={req_id} not found in deferred_free_blocks" + self.block_manager.deallocate(seq) + + for req_id in kv_connector_output.finished_saving or (): + if hasattr(self.kv_connector, "save_finished"): + self.kv_connector.save_finished(req_id) + seq = self.deferred_free_blocks.get(req_id) + if seq is None: + try: + seq = self.deferred_free_blocks.get(int(req_id)) + except (TypeError, ValueError): + seq = None + if seq is not None and not ( + hasattr(self.kv_connector, "should_defer_free") + and self.kv_connector.should_defer_free(seq) + ): + self.block_manager.deallocate(_pop_deferred(req_id)) def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py new file mode 100644 index 0000000000..558675eca9 --- /dev/null +++ b/tests/test_lmcache_offload_connector.py @@ -0,0 +1,271 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +from __future__ import annotations + +import threading +import sys +import types +from types import SimpleNamespace + +import pytest + +try: + import torch # noqa: F401 +except ModuleNotFoundError: + sys.modules["torch"] = types.ModuleType("torch") + +from atom.kv_transfer.disaggregation import KVConnectorOutput, KVOutputAggregator +from atom.kv_transfer.offload.connector import ( + LMCacheOffloadConnector, + LMCacheOffloadConnectorScheduler, +) +from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec +from atom.model_engine.scheduler import Scheduler + + +class _LookupClient: + def __init__(self, hit: int) -> None: + self.hit = hit + + def lookup(self, token_ids, lookup_id): + return self.hit + + +def _scheduler() -> LMCacheOffloadConnectorScheduler: + sched = LMCacheOffloadConnectorScheduler.__new__(LMCacheOffloadConnectorScheduler) + sched._config = SimpleNamespace() + sched.kv_role = "offload" + sched.block_size = 4 + sched.chunk_size = 4 + sched._lookup_client = _LookupClient(hit=0) + sched._load_specs = {} + sched._reqs_need_recv = {} + sched._save_tracker = {} + sched._save_inflight = set() + sched._lookup_in_step = [] + return sched + + +@pytest.mark.parametrize("layout", ["segment", "segment_indexed"]) +def test_segment_major_codec_roundtrip_noncontiguous_blocks(monkeypatch, layout): + import torch + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", layout) + + original = { + "l0": SimpleNamespace( + k_cache=torch.arange(8 * 2 * 3, dtype=torch.uint8).reshape(8, 2, 3), + v_cache=(torch.arange(8 * 4, dtype=torch.uint8).reshape(8, 4) + 51), + k_scale=torch.arange(8, dtype=torch.uint8).reshape(8, 1) + 101, + v_scale=torch.arange(8, dtype=torch.uint8).reshape(8, 1) + 151, + ), + "l1": SimpleNamespace( + k_cache=(torch.arange(8 * 3, dtype=torch.uint8).reshape(8, 3) + 201), + v_cache=(torch.arange(8 * 2, dtype=torch.uint8).reshape(8, 2) + 31), + k_scale=None, + v_scale=None, + ), + } + kv_caches = { + name: SimpleNamespace( + k_cache=layer.k_cache.clone(), + v_cache=layer.v_cache.clone(), + k_scale=layer.k_scale.clone() if layer.k_scale is not None else None, + v_scale=layer.v_scale.clone() if layer.v_scale is not None else None, + ) + for name, layer in original.items() + } + + codec = ATOMKVByteCodec(kv_caches) + block_ids = [1, 2, 4, 6, 7] + host = torch.empty(len(block_ids) * codec.bytes_per_block, dtype=torch.uint8) + + codec.gpu_to_host(host, block_ids) + expected_calls = codec.segments_per_block * (3 if layout == "segment" else 2) + assert codec.copy_calls_for_block_ids(block_ids) == expected_calls + + for layer in kv_caches.values(): + layer.k_cache.zero_() + layer.v_cache.zero_() + if layer.k_scale is not None: + layer.k_scale.zero_() + if layer.v_scale is not None: + layer.v_scale.zero_() + + codec.host_to_gpu(host, block_ids) + + for name, layer in kv_caches.items(): + src = original[name] + for bid in block_ids: + assert torch.equal(layer.k_cache[bid], src.k_cache[bid]) + assert torch.equal(layer.v_cache[bid], src.v_cache[bid]) + if layer.k_scale is not None: + assert torch.equal(layer.k_scale[bid], src.k_scale[bid]) + if layer.v_scale is not None: + assert torch.equal(layer.v_scale[bid], src.v_scale[bid]) + + +def test_segment_indexed_stitches_chunk_buffers(monkeypatch): + import torch + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") + kv_caches = { + "l0": SimpleNamespace( + k_cache=torch.arange(8 * 2 * 3, dtype=torch.uint8).reshape(8, 2, 3), + v_cache=(torch.arange(8 * 4, dtype=torch.uint8).reshape(8, 4) + 51), + k_scale=torch.arange(8, dtype=torch.uint8).reshape(8, 1) + 101, + v_scale=torch.arange(8, dtype=torch.uint8).reshape(8, 1) + 151, + ), + "l1": SimpleNamespace( + k_cache=(torch.arange(8 * 3, dtype=torch.uint8).reshape(8, 3) + 201), + v_cache=(torch.arange(8 * 2, dtype=torch.uint8).reshape(8, 2) + 31), + k_scale=None, + v_scale=None, + ), + } + codec = ATOMKVByteCodec(kv_caches) + chunks = [[1, 2], [4], [6, 7]] + flat_ids = [bid for bids in chunks for bid in bids] + direct = torch.empty(len(flat_ids) * codec.bytes_per_block, dtype=torch.uint8) + codec.gpu_to_host(direct, flat_ids) + + chunk_buffers = [] + for bids in chunks: + host = torch.empty(len(bids) * codec.bytes_per_block, dtype=torch.uint8) + codec.gpu_to_host(host, bids) + chunk_buffers.append(host) + + stitched = torch.empty_like(direct) + codec.stitch_chunk_buffers(stitched, chunk_buffers, [len(b) for b in chunks]) + + assert torch.equal(stitched, direct) + + +def test_full_prompt_hit_is_clamped_before_load_spec(): + sched = _scheduler() + sched._lookup_client = _LookupClient(hit=8) + seq = SimpleNamespace( + id=123, + num_prompt_tokens=8, + token_ids=list(range(8)), + num_cached_tokens=0, + ) + + need, should_park = sched.get_num_new_matched_tokens(seq) + + assert need == 7 + assert should_park is True + assert sched._load_specs[str(seq.id)].lmcache_cached_tokens == 7 + + +def test_load_exception_is_reported_as_failed_recving(): + conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) + conn._lock = threading.Lock() + conn._done_load = set() + conn._done_save = set() + conn._failed_load = set() + req = SimpleNamespace(req_id=42) + + def boom(_req): + raise RuntimeError("load failed") + + conn._guard("load", boom, req) + + assert conn._done_load == set() + assert conn._failed_load == {42} + + +def test_aggregator_emits_failed_recving_if_any_worker_failed(): + agg = KVOutputAggregator(world_size=2) + + result = agg.aggregate( + [ + KVConnectorOutput(finished_recving={77}), + KVConnectorOutput(failed_recving={77}), + ] + ) + + assert result.finished_recving == set() + assert result.failed_recving == {77} + + +def test_aggregator_failure_overrides_late_success(): + agg = KVOutputAggregator(world_size=2) + + result = agg.aggregate( + [ + KVConnectorOutput(finished_recving={77}, failed_recving={77}), + KVConnectorOutput(finished_recving={77}), + ] + ) + + assert result.finished_recving == set() + assert result.failed_recving == {77} + assert agg.pending_count == (0, 0) + + +def test_save_inflight_defers_free_until_save_finishes(): + sched = _scheduler() + seq = SimpleNamespace( + id=9, + token_ids=list(range(8)), + block_table=[3, 4], + num_prompt_tokens=8, + prefix_hashes_published=True, + ) + sched._save_tracker[str(seq.id)] = [seq, 0] + + meta = sched.build_connector_meta() + + assert len(meta.requests) == 1 + assert meta.requests[0].save_spec is not None + assert sched.should_defer_free(seq) is True + + sched.save_finished(seq.id) + + assert sched.should_defer_free(seq) is False + + +def test_finished_saving_releases_deferred_free_with_string_req_id(): + class _BlockManager: + def __init__(self) -> None: + self.deallocated = [] + + def deallocate(self, seq) -> None: + self.deallocated.append(seq.id) + + class _Connector: + is_producer = False + + def __init__(self) -> None: + self.inflight = {"9"} + + def save_finished(self, req_id) -> None: + self.inflight.discard(str(req_id)) + + def should_defer_free(self, seq) -> bool: + return str(seq.id) in self.inflight + + sched = Scheduler.__new__(Scheduler) + sched.block_manager = _BlockManager() + sched.kv_connector = _Connector() + seq = SimpleNamespace(id=9) + sched.deferred_free_blocks = {seq.id: seq} + + sched._update_from_kv_xfer_finished(KVConnectorOutput(finished_saving={"9"})) + + assert sched.block_manager.deallocated == [9] + assert sched.deferred_free_blocks == {} + + +def test_finished_recv_matches_string_req_id(): + sched = Scheduler.__new__(Scheduler) + sched.finished_recving_kv_req_ids = ["123"] + + assert sched._update_waiting_for_remote_kv(SimpleNamespace(id=123)) is True + assert sched.finished_recving_kv_req_ids == [] From 7ec6015c846a63e15b0540b952e85b715221c8b3 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Sun, 31 May 2026 22:18:53 -0500 Subject: [PATCH 03/27] Optimize LMCache offload chunked prefill path --- atom/kv_transfer/disaggregation/aggregator.py | 23 + atom/kv_transfer/offload/connector.py | 580 ++++++++++++++++-- atom/kv_transfer/offload/gpu_connector.py | 261 +++++--- atom/kv_transfer/offload/native_stitch.cpp | 69 +++ atom/kv_transfer/offload/native_stitch.py | 9 + atom/kv_transfer/offload/trace.py | 34 + atom/model_engine/engine_core.py | 13 + atom/model_engine/model_runner.py | 38 ++ atom/model_engine/scheduler.py | 85 ++- tests/conftest.py | 13 +- tests/test_lmcache_offload_connector.py | 255 ++++++++ tests/test_scheduler.py | 36 ++ 12 files changed, 1284 insertions(+), 132 deletions(-) create mode 100644 atom/kv_transfer/offload/trace.py diff --git a/atom/kv_transfer/disaggregation/aggregator.py b/atom/kv_transfer/disaggregation/aggregator.py index ff11059b8d..3fc6c2c227 100644 --- a/atom/kv_transfer/disaggregation/aggregator.py +++ b/atom/kv_transfer/disaggregation/aggregator.py @@ -19,6 +19,7 @@ import logging from atom.kv_transfer.disaggregation.types import KVConnectorOutput, ReqId +from atom.kv_transfer.offload.trace import offload_trace logger = logging.getLogger("atom") @@ -78,9 +79,23 @@ def aggregate(self, worker_outputs: list[KVConnectorOutput]) -> KVConnectorOutpu if wo.finished_recving: for rid in wo.finished_recving: self._seen_recving.setdefault(rid, set()).add(worker_idx) + offload_trace( + "aggregator_worker_recv_done", + worker=worker_idx, + req=rid, + seen=len(self._seen_recving[rid]), + world=self._world_size, + ) if wo.failed_recving: for rid in wo.failed_recving: self._seen_recv_failed.setdefault(rid, set()).add(worker_idx) + offload_trace( + "aggregator_worker_recv_failed", + worker=worker_idx, + req=rid, + seen=len(self._seen_recv_failed[rid]), + world=self._world_size, + ) if wo.finished_saving: for rid in wo.finished_saving: self._seen_saving.setdefault(rid, set()).add(worker_idx) @@ -119,6 +134,14 @@ def aggregate(self, worker_outputs: list[KVConnectorOutput]) -> KVConnectorOutpu for rid in done_saving: del self._seen_saving[rid] + if done_recving or failed_recving or done_saving: + offload_trace( + "aggregator_done", + recv=sorted(done_recving), + failed=sorted(failed_recving), + saving=sorted(done_saving), + ) + return KVConnectorOutput( finished_sending=done_sending, finished_recving=done_recving, diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index 3415cabc93..658c320fef 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -44,6 +44,7 @@ LoadSpec, SaveSpec, ) +from atom.kv_transfer.offload.trace import offload_trace logger = logging.getLogger("atom") @@ -111,6 +112,9 @@ def __init__(self, config) -> None: self._done_save: set[ReqId] = set() self._failed_load: set[ReqId] = set() self._load_active = threading.Event() + self._request_fastpath = os.environ.get( + "OFFLOAD_REQUEST_FASTPATH", "1" + ).lower() not in ("0", "false", "no", "off") self._engine = None self._sm = None @@ -176,8 +180,24 @@ def start_load_kv(self, metadata) -> None: return for req in metadata.requests: if req.load_spec is not None and self._do_load: + offload_trace( + "worker_load_enqueue", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + hbm=req.load_spec.hbm_cached_tokens, + lmc=req.load_spec.lmcache_cached_tokens, + blocks=len(req.block_ids), + ) self._load_executor.submit(self._guard, "load", self._do_load_req, req) if req.save_spec is not None and self._do_save: + offload_trace( + "worker_save_enqueue", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + skip=req.save_spec.skip_leading_tokens, + toks=len(req.token_ids), + blocks=len(req.block_ids), + ) self._save_executor.submit(self._guard, "save", self._do_save_req, req) def _guard(self, kind: str, fn, req) -> None: @@ -212,12 +232,32 @@ def _lookup_unpin(self, req_id) -> None: except Exception: pass + def _copy_device(self) -> torch.device | None: + codec = getattr(self, "_codec", None) + device = getattr(codec, "device", None) + if device is None: + return None + device = torch.device(device) + if device.type != "cuda": + return None + return device + def _stream(self) -> torch.cuda.Stream: - """A CUDA stream owned by the calling copy-daemon thread (lazily made).""" - s = getattr(self._tls, "stream", None) + """A CUDA stream owned by the calling copy-daemon thread and device.""" + device = self._copy_device() + key = str(device) if device is not None else "default" + streams = getattr(self._tls, "streams", None) + if streams is None: + streams = {} + self._tls.streams = streams + s = streams.get(key) if s is None: - s = torch.cuda.Stream() - self._tls.stream = s + if device is None: + s = torch.cuda.Stream() + else: + with torch.cuda.device(device): + s = torch.cuda.Stream() + streams[key] = s return s def _host_tmp(self, nbytes: int) -> torch.Tensor: @@ -256,19 +296,90 @@ def _profile_enabled(self) -> bool: "off", ) + def _request_fastpath_enabled(self) -> bool: + return ( + bool(getattr(self, "_request_fastpath", False)) + and self._codec is not None + and self._codec.layout == "segment_indexed" + ) + + def _request_level_key(self, chunks, token_count: int): + """Synthetic key for a whole-prefix segment-major object. + + The normal LMCache chunk keys remain authoritative for lookup. This key + is an optional per-rank fast path for exact full-prefix reloads: it uses + the last chunk's prefix hash plus tags, so it cannot collide with normal + chunk entries and stays stable across scheduler/worker processes. + """ + if not chunks: + return None + key = chunks[-1][2] + request_configs = dict(getattr(key, "request_configs", None) or {}) + request_configs["lmcache.tag.atom_offload"] = "request" + request_configs["lmcache.tag.atom_offload_tokens"] = str(int(token_count)) + request_configs["lmcache.tag.atom_offload_layout"] = str( + getattr(self._codec, "layout", "unknown") + ) + return key.__class__( + model_name=key.model_name, + world_size=key.world_size, + worker_id=key.worker_id, + chunk_hash=key.chunk_hash, + dtype=key.dtype, + request_configs=request_configs, + ) + # -- copy daemon thread ---------------------------------------------- def _do_load_req(self, req: LMCacheReqMeta) -> None: ls = req.load_spec assert ls is not None - stream = self._stream() hbm = int(ls.hbm_cached_tokens) toks = req.token_ids[: ls.lmcache_cached_tokens] t_total0 = time.perf_counter() + offload_trace( + "worker_load_start", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + hbm=hbm, + lmc=ls.lmcache_cached_tokens, + toks=len(toks), + blocks=len(req.block_ids), + ) if int(ls.lmcache_cached_tokens) <= hbm: self._lookup_unpin(req.req_id) with self._lock: self._done_load.add(req.req_id) + offload_trace( + "worker_load_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="hbm_only", + total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", + ) + return + chunk_size = int(self.chunk_size or 256) + if hbm % chunk_size != 0: + logger.warning( + "LMCache offload: HBM prefix is not chunk-aligned req=%s " + "hbm=%d chunk=%d; re-prefill", + req.req_id, + hbm, + chunk_size, + ) + self._lookup_unpin(req.req_id) + with self._lock: + self._failed_load.add(req.req_id) + offload_trace( + "worker_load_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="unaligned_hbm", + hbm=hbm, + chunk=chunk_size, + total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", + ) return + stream = self._stream() mask = torch.ones(len(toks), dtype=torch.bool) mask[:hbm] = False t0 = time.perf_counter() @@ -287,6 +398,13 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: self._lookup_unpin(req.req_id) with self._lock: self._failed_load.add(req.req_id) + offload_trace( + "worker_load_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="no_chunks", + total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", + ) return for (s, _e, _key) in chunks: if s < hbm: @@ -298,8 +416,103 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: self._lookup_unpin(req.req_id) with self._lock: self._failed_load.add(req.req_id) + offload_trace( + "worker_load_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="overlap_hbm", + hbm=hbm, + chunk_start=s, + total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", + ) return contains_ms = 0.0 + loaded_objs = [] + get_ms = 0.0 + host_alloc_ms = 0.0 + stitch_ms = 0.0 + h2d_submit_ms = 0.0 + sync_ms = 0.0 + nblocks = 0 + nbytes = 0 + copy_calls = 0 + chunk_bids: list[list[int]] = [ + self._block_ids(req, s, e) for (s, e, _key) in chunks + ] + all_bids = [bid for bids in chunk_bids for bid in bids] + nblocks = len(all_bids) + nbytes = nblocks * self._codec.bytes_per_block + + request_key = None + if hbm == 0 and self._request_fastpath_enabled(): + request_key = self._request_level_key(chunks, len(toks)) + if request_key is not None: + req_mo = None + t0 = time.perf_counter() + request_location = self._sm.contains(request_key) + contains_ms += (time.perf_counter() - t0) * 1000 + if request_location: + try: + t0 = time.perf_counter() + req_mo = self._sm.get(request_key) + get_ms += (time.perf_counter() - t0) * 1000 + if req_mo is not None: + copy_calls = self._codec.copy_calls_for_block_ids(all_bids) + t0 = time.perf_counter() + self._codec.host_to_gpu(req_mo.tensor, all_bids, stream) + h2d_submit_ms += (time.perf_counter() - t0) * 1000 + t0 = time.perf_counter() + stream.synchronize() + sync_ms += (time.perf_counter() - t0) * 1000 + self._lookup_unpin(req.req_id) + with self._lock: + self._done_load.add(req.req_id) + total_ms = (time.perf_counter() - t_total0) * 1000 + offload_trace( + "worker_load_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="ok_request", + chunks=len(chunks), + blocks=nblocks, + bytes_gib=f"{nbytes / 1024**3:.3f}", + stitch_ms=f"{stitch_ms:.2f}", + h2d_submit_ms=f"{h2d_submit_ms:.2f}", + sync_ms=f"{sync_ms:.2f}", + total_ms=f"{total_ms:.2f}", + ) + if self._profile_enabled(): + logger.info( + "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " + "chunks=%d blocks=%d bytes=%.3fGiB copy_calls=%d " + "layout=%s fastpath=request process_ms=%.2f " + "contains_ms=%.2f get_ms=%.2f host_alloc_ms=%.2f " + "stitch_ms=%.2f h2d_submit_ms=%.2f sync_ms=%.2f " + "total_ms=%.2f", + getattr(self, "_rank", "?"), + req.req_id, + hbm, + ls.lmcache_cached_tokens, + len(chunks), + nblocks, + nbytes / 1024**3, + copy_calls, + self._codec.layout, + process_ms, + contains_ms, + get_ms, + host_alloc_ms, + stitch_ms, + h2d_submit_ms, + sync_ms, + total_ms, + ) + logger.info("offload _do_load DONE req=%s", req.req_id) + return + finally: + if req_mo is not None: + req_mo.ref_count_down() + for (_s, _e, key) in chunks: t0 = time.perf_counter() if not self._sm.contains(key): @@ -308,19 +521,17 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: self._lookup_unpin(req.req_id) with self._lock: self._failed_load.add(req.req_id) + offload_trace( + "worker_load_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="miss", + chunks=len(chunks), + total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", + ) return contains_ms += (time.perf_counter() - t0) * 1000 - loaded_objs = [] - get_ms = 0.0 - host_alloc_ms = 0.0 - stitch_ms = 0.0 - h2d_submit_ms = 0.0 - sync_ms = 0.0 - nblocks = 0 - nbytes = 0 - copy_calls = 0 - chunk_bids: list[list[int]] = [] try: for (s, e, key) in chunks: t0 = time.perf_counter() @@ -335,19 +546,23 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: self._lookup_unpin(req.req_id) with self._lock: self._failed_load.add(req.req_id) + offload_trace( + "worker_load_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="get_none", + chunks=len(chunks), + total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", + ) return loaded_objs.append(mo) - bids = self._block_ids(req, s, e) - chunk_bids.append(bids) - nblocks += len(bids) - nbytes += len(bids) * self._codec.bytes_per_block + bids = chunk_bids[len(loaded_objs) - 1] if self._codec.layout != "segment_indexed": copy_calls += self._codec.copy_calls_for_block_ids(bids) t0 = time.perf_counter() self._codec.host_to_gpu(mo.tensor, bids, stream) h2d_submit_ms += (time.perf_counter() - t0) * 1000 if self._codec.layout == "segment_indexed": - all_bids = [bid for bids in chunk_bids for bid in bids] copy_calls = self._codec.copy_calls_for_block_ids(all_bids) t0 = time.perf_counter() req_buf = self._host_tmp(nbytes) @@ -382,14 +597,27 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: self._lookup_unpin(req.req_id) with self._lock: self._done_load.add(req.req_id) + total_ms = (time.perf_counter() - t_total0) * 1000 + offload_trace( + "worker_load_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="ok", + chunks=len(chunks), + blocks=nblocks, + bytes_gib=f"{nbytes / 1024**3:.3f}", + stitch_ms=f"{stitch_ms:.2f}", + h2d_submit_ms=f"{h2d_submit_ms:.2f}", + sync_ms=f"{sync_ms:.2f}", + total_ms=f"{total_ms:.2f}", + ) if self._profile_enabled(): - total_ms = (time.perf_counter() - t_total0) * 1000 logger.info( "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " "chunks=%d blocks=%d bytes=%.3fGiB copy_calls=%d " - "layout=%s process_ms=%.2f contains_ms=%.2f get_ms=%.2f " - "host_alloc_ms=%.2f stitch_ms=%.2f h2d_submit_ms=%.2f " - "sync_ms=%.2f total_ms=%.2f", + "layout=%s fastpath=chunk process_ms=%.2f contains_ms=%.2f " + "get_ms=%.2f host_alloc_ms=%.2f stitch_ms=%.2f " + "h2d_submit_ms=%.2f sync_ms=%.2f total_ms=%.2f", getattr(self, "_rank", "?"), req.req_id, hbm, @@ -423,9 +651,24 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: if skip >= len(toks): with self._lock: self._done_save.add(req.req_id) + offload_trace( + "worker_save_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="skip", + toks=len(toks), + ) return t_total0 = time.perf_counter() + offload_trace( + "worker_save_start", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + skip=skip, + toks=len(toks), + blocks=len(req.block_ids), + ) mask = torch.ones(len(toks), dtype=torch.bool) mask[:skip] = False t0 = time.perf_counter() @@ -433,15 +676,30 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: process_ms = (time.perf_counter() - t0) * 1000 keys, objs, already = [], [], 0 + request_key = None + request_obj = None + request_fastpath = "off" + if skip == 0 and self._request_fastpath_enabled() and chunks: + request_key = self._request_level_key(chunks, len(toks)) + request_fastpath = "miss" + t0 = time.perf_counter() + if self._sm.contains(request_key): + request_key = None + request_fastpath = "hit" + contains_ms = (time.perf_counter() - t0) * 1000 + else: + contains_ms = 0.0 put_started = False - contains_ms = 0.0 alloc_ms = 0.0 + host_alloc_ms = 0.0 d2h_submit_ms = 0.0 sync_ms = 0.0 + split_ms = 0.0 put_ms = 0.0 nblocks = 0 total_nbytes = 0 copy_calls = 0 + chunk_bids: list[list[int]] = [] try: for (s, e, key) in chunks: self._pause_save_for_load(stream) @@ -461,21 +719,64 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: break keys.append(key) objs.append(mo) + chunk_bids.append(bids) nblocks += len(bids) total_nbytes += chunk_nbytes - copy_calls += self._codec.copy_calls_for_block_ids(bids) - # D2H on this thread's dedicated copy stream (off the compute stream). - t0 = time.perf_counter() - self._codec.gpu_to_host(mo.tensor, bids, stream) - d2h_submit_ms += (time.perf_counter() - t0) * 1000 + if self._codec.layout != "segment_indexed": + copy_calls += self._codec.copy_calls_for_block_ids(bids) + # D2H on this thread's dedicated copy stream (off compute stream). + t0 = time.perf_counter() + self._codec.gpu_to_host(mo.tensor, bids, stream) + d2h_submit_ms += (time.perf_counter() - t0) * 1000 if keys: + if self._codec.layout == "segment_indexed": + all_bids = [bid for bids in chunk_bids for bid in bids] + copy_calls = self._codec.copy_calls_for_block_ids(all_bids) + if request_key is not None and len(keys) == len(chunks): + t0 = time.perf_counter() + request_obj = self._sm.allocate( + torch.Size((total_nbytes,)), + torch.uint8, + fmt=MemoryFormat.KV_2LTD, + ) + alloc_ms += (time.perf_counter() - t0) * 1000 + if request_obj is not None: + req_buf = request_obj.tensor + request_fastpath = "stored" + else: + request_key = None + request_fastpath = "alloc_failed" + else: + request_key = None + if request_fastpath == "miss": + request_fastpath = "partial_skip" + if request_obj is None: + t0 = time.perf_counter() + req_buf = self._host_tmp(total_nbytes) + host_alloc_ms += (time.perf_counter() - t0) * 1000 + t0 = time.perf_counter() + self._codec.gpu_to_host(req_buf, all_bids, stream) + d2h_submit_ms += (time.perf_counter() - t0) * 1000 t0 = time.perf_counter() stream.synchronize() # stream-specific sync_ms += (time.perf_counter() - t0) * 1000 + if self._codec.layout == "segment_indexed": + t0 = time.perf_counter() + self._codec.split_request_buffer( + req_buf, + [mo.tensor for mo in objs], + [len(bids) for bids in chunk_bids], + ) + split_ms += (time.perf_counter() - t0) * 1000 put_started = True t0 = time.perf_counter() - self._sm.batched_put(keys, objs) + put_keys = list(keys) + put_objs = list(objs) + if request_key is not None and request_obj is not None: + put_keys.append(request_key) + put_objs.append(request_obj) + self._sm.batched_put(put_keys, put_objs) put_ms += (time.perf_counter() - t0) * 1000 except Exception: if not put_started: @@ -484,18 +785,39 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: stream.synchronize() sync_ms += (time.perf_counter() - t0) * 1000 finally: - for mo in objs: + cleanup_objs = list(objs) + if request_obj is not None: + cleanup_objs.append(request_obj) + for mo in cleanup_objs: mo.ref_count_down() raise with self._lock: self._done_save.add(req.req_id) + total_ms = (time.perf_counter() - t_total0) * 1000 + offload_trace( + "worker_save_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="ok", + toks=len(toks), + chunks=len(chunks), + stored=len(keys), + blocks=nblocks, + bytes_gib=f"{total_nbytes / 1024**3:.3f}", + d2h_submit_ms=f"{d2h_submit_ms:.2f}", + sync_ms=f"{sync_ms:.2f}", + split_ms=f"{split_ms:.2f}", + request_fastpath=request_fastpath, + total_ms=f"{total_ms:.2f}", + ) if self._profile_enabled(): - total_ms = (time.perf_counter() - t_total0) * 1000 logger.info( "[OFFLOAD-SAVE-PROF] rank=%s req=%s toks=%d chunks=%d " "stored=%d already=%d blocks=%d bytes=%.3fGiB copy_calls=%d " - "layout=%s process_ms=%.2f contains_ms=%.2f alloc_ms=%.2f " - "d2h_submit_ms=%.2f sync_ms=%.2f put_ms=%.2f total_ms=%.2f", + "layout=%s request_fastpath=%s process_ms=%.2f " + "contains_ms=%.2f alloc_ms=%.2f host_alloc_ms=%.2f " + "d2h_submit_ms=%.2f sync_ms=%.2f split_ms=%.2f " + "put_ms=%.2f total_ms=%.2f", getattr(self, "_rank", "?"), req.req_id, len(toks), @@ -506,11 +828,14 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: total_nbytes / 1024**3, copy_calls, self._codec.layout, + request_fastpath, process_ms, contains_ms, alloc_ms, + host_alloc_ms, d2h_submit_ms, sync_ms, + split_ms, put_ms, total_ms, ) @@ -535,6 +860,14 @@ def get_finished(self) -> KVConnectorOutput: self._done_save.clear() self._done_load.clear() self._failed_load.clear() + if dl or fl or ds: + offload_trace( + "worker_get_finished", + rank=getattr(self, "_rank", "?"), + done_load=sorted(dl), + failed_load=sorted(fl), + done_save=sorted(ds), + ) return KVConnectorOutput( finished_sending=set(), finished_recving=dl, @@ -595,11 +928,25 @@ def get_num_new_matched_tokens(self, seq) -> tuple[int, bool]: return 0, False num_prompt = seq.num_prompt_tokens token_ids = list(seq.token_ids[:num_prompt]) + offload_trace( + "scheduler_lookup_start", + req=seq.id, + prompt=num_prompt, + hbm=seq.num_cached_tokens, + ) + t_lookup0 = time.perf_counter() try: hit = self._lookup_client.lookup(token_ids, lookup_id=str(seq.id)) except Exception: logger.exception("LMCache offload lookup failed for seq %s", seq.id) + offload_trace( + "scheduler_lookup_done", + req=seq.id, + status="exception", + lookup_ms=f"{(time.perf_counter() - t_lookup0) * 1000:.2f}", + ) return 0, False + lookup_ms = (time.perf_counter() - t_lookup0) * 1000 if logger.isEnabledFor(logging.DEBUG): _lh = None try: @@ -612,6 +959,13 @@ def get_num_new_matched_tokens(self, seq) -> tuple[int, bool]: logger.debug("[OFFLOAD-LOOKUP] seq=%s num_prompt=%d hbm_cached=%d hit=%s lookuphash3=%s", seq.id, num_prompt, int(seq.num_cached_tokens), hit, _lh) if not hit: + offload_trace( + "scheduler_lookup_done", + req=seq.id, + status="miss", + hit=hit, + lookup_ms=f"{lookup_ms:.2f}", + ) return 0, False sid = str(seq.id) hit = int(hit) @@ -624,6 +978,14 @@ def get_num_new_matched_tokens(self, seq) -> tuple[int, bool]: self._lookup_client.clear_lookup_status(sid) except Exception: pass + offload_trace( + "scheduler_lookup_done", + req=seq.id, + status="hbm_satisfies", + hit=hit, + hbm=seq.num_cached_tokens, + lookup_ms=f"{lookup_ms:.2f}", + ) return 0, False self._lookup_in_step.append(sid) self._load_specs[sid] = LoadSpec( @@ -631,6 +993,15 @@ def get_num_new_matched_tokens(self, seq) -> tuple[int, bool]: lmcache_cached_tokens=hit, can_load=False, ) + offload_trace( + "scheduler_lookup_done", + req=seq.id, + status="need_load", + hit=hit, + hbm=seq.num_cached_tokens, + need=need, + lookup_ms=f"{lookup_ms:.2f}", + ) return need, True # True => park in WAITING_FOR_REMOTE_KVS def update_state_after_alloc(self, seq) -> None: @@ -641,8 +1012,15 @@ def update_state_after_alloc(self, seq) -> None: if ls is not None: ls.can_load = True self._reqs_need_recv[sid] = seq - # Track for save; the prompt prefix is offloaded later, once prefill has - # actually computed it (checked via prefix_hashes_published in build). + offload_trace( + "scheduler_load_alloc_ready", + req=seq.id, + hbm=seq.num_cached_tokens, + lmc=ls.lmcache_cached_tokens, + blocks=len(seq.block_table), + ) + # Track for save; build_connector_meta stores chunks once the scheduler's + # computed frontier (seq.num_cached_tokens) has advanced past them. if sid not in self._save_tracker: self._save_tracker[sid] = [seq, 0] @@ -667,6 +1045,67 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: # prefix-cache blocks (possibly shared with other seqs) -> output # corruption. So load only [hbm_cached, offload_hit). ls.hbm_cached_tokens = int(seq.num_cached_tokens) + if ls.hbm_cached_tokens >= int(ls.lmcache_cached_tokens): + seq.offload_loaded_tokens = int(seq.num_cached_tokens) + logger.info( + "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " + "reason=hbm_satisfies_after_alloc", + seq.id, + ls.hbm_cached_tokens, + ls.lmcache_cached_tokens, + ) + offload_trace( + "scheduler_load_hbm_satisfies_after_alloc", + req=seq.id, + hbm=ls.hbm_cached_tokens, + lmc=ls.lmcache_cached_tokens, + blocks=len(list(seq.block_table)), + ) + # The request may already be parked in WAITING_FOR_REMOTE_KVS. + # Emit a no-op load so every worker reports finished_recving via + # the normal aggregation path instead of trying to complete it + # locally in the scheduler process. + meta.add_request(LMCacheReqMeta( + req_id=seq.id, + token_ids=list(seq.token_ids[: ls.lmcache_cached_tokens]), + block_ids=list(seq.block_table), + load_spec=ls, + )) + continue + chunk = self.chunk_size or 256 + if ls.hbm_cached_tokens % chunk != 0: + seq.offload_loaded_tokens = int(seq.num_cached_tokens) + logger.info( + "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " + "reason=unaligned_hbm chunk=%d", + seq.id, + ls.hbm_cached_tokens, + ls.lmcache_cached_tokens, + chunk, + ) + offload_trace( + "scheduler_load_unaligned_hbm", + req=seq.id, + hbm=ls.hbm_cached_tokens, + lmc=ls.lmcache_cached_tokens, + chunk=chunk, + blocks=len(list(seq.block_table)), + ) + # LMCache chunks can only be loaded from a chunk boundary. Do + # not round down and overwrite HBM prefix-cache blocks that may + # be shared with other requests; wake the parked request and let + # it continue prefill from the HBM floor. + meta.add_request(LMCacheReqMeta( + req_id=seq.id, + token_ids=list(seq.token_ids[: ls.hbm_cached_tokens]), + block_ids=list(seq.block_table), + load_spec=LoadSpec( + hbm_cached_tokens=ls.hbm_cached_tokens, + lmcache_cached_tokens=ls.hbm_cached_tokens, + can_load=True, + ), + )) + continue # num_cached after load = max(HBM, offload); never drop below HBM. seq.offload_loaded_tokens = max( int(seq.num_cached_tokens), int(ls.lmcache_cached_tokens) @@ -677,41 +1116,85 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: logger.info("[OFFLOAD-LOAD-EMIT] seq=%s hbm_cached=%d lmc_cached=%d offload_loaded=%d nblocks=%d", seq.id, ls.hbm_cached_tokens, ls.lmcache_cached_tokens, seq.offload_loaded_tokens, len(list(seq.block_table))) + offload_trace( + "scheduler_load_emit", + req=seq.id, + hbm=ls.hbm_cached_tokens, + lmc=ls.lmcache_cached_tokens, + offload_loaded=seq.offload_loaded_tokens, + blocks=len(list(seq.block_table)), + ) meta.add_request(LMCacheReqMeta( req_id=seq.id, token_ids=list(seq.token_ids[: ls.lmcache_cached_tokens]), block_ids=list(seq.block_table), load_spec=ls, )) - # Saves: store the prompt prefix once prefill has computed it. We detect - # "computed" via seq.prefix_hashes_published (set in postprocess after the - # prefill step), so the blocks we D2H are already written -- no race with - # forward. Persistent tracker: each chunk is stored once. + # Saves: store fully computed prompt chunks. Under scheduler-side + # chunked prefill, seq.num_cached_tokens advances after each prefill + # chunk's forward has completed; use it as the D2H-safe frontier. chunk = self.chunk_size or 256 for sid, entry in self._save_tracker.items(): seq, saved = entry if sid in self._reqs_need_recv: continue # loading this step; defer its save - if not getattr(seq, "prefix_hashes_published", False): - continue # prefill not finished computing the prompt yet - aligned = (int(seq.num_prompt_tokens) // chunk) * chunk + if sid in self._save_inflight: + continue # keep at most one save per request in flight + computed = min( + int(getattr(seq, "num_cached_tokens", 0)), + int(seq.num_prompt_tokens), + ) + is_last_prefill = computed >= int(seq.num_prompt_tokens) + aligned = (computed // chunk) * chunk if aligned <= saved: continue - logger.debug("[OFFLOAD-SAVE-EMIT] seq=%s num_prompt=%d aligned=%d saved=%d", - seq.id, int(seq.num_prompt_tokens), aligned, saved) + logger.debug( + "[OFFLOAD-SAVE-EMIT] seq=%s computed=%d num_prompt=%d aligned=%d saved=%d", + seq.id, + computed, + int(seq.num_prompt_tokens), + aligned, + saved, + ) + offload_trace( + "scheduler_save_emit", + req=seq.id, + prompt=seq.num_prompt_tokens, + computed=computed, + aligned=aligned, + saved=saved, + blocks=len(seq.block_table), + ) meta.add_request(LMCacheReqMeta( req_id=seq.id, token_ids=list(seq.token_ids[:aligned]), block_ids=list(seq.block_table), save_spec=SaveSpec(skip_leading_tokens=saved, can_save=True), + is_last_prefill=is_last_prefill, )) entry[1] = aligned self._save_inflight.add(sid) self._reqs_need_recv.clear() return meta + def _save_frontier(self, seq) -> int: + chunk = self.chunk_size or 256 + computed = min( + int(getattr(seq, "num_cached_tokens", 0)), + int(getattr(seq, "num_prompt_tokens", 0)), + ) + return (computed // chunk) * chunk + + def _has_pending_save(self, seq) -> bool: + sid = str(seq.id) + entry = self._save_tracker.get(sid) + if entry is None: + return False + return self._save_frontier(seq) > int(entry[1]) + def should_defer_free(self, seq) -> bool: - return str(seq.id) in self._save_inflight + sid = str(seq.id) + return sid in self._save_inflight or self._has_pending_save(seq) def save_finished(self, req_id) -> None: self._save_inflight.discard(str(req_id)) @@ -724,7 +1207,8 @@ def request_finished(self, seq) -> None: sid = str(seq.id) self._load_specs.pop(sid, None) self._reqs_need_recv.pop(sid, None) - self._save_tracker.pop(sid, None) + if not self.should_defer_free(seq): + self._save_tracker.pop(sid, None) if self._lookup_client is not None: try: self._lookup_client.clear_lookup_status(sid) diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py index 321996a353..9f4009bbee 100644 --- a/atom/kv_transfer/offload/gpu_connector.py +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -29,6 +29,7 @@ from __future__ import annotations import logging +import operator import os import threading @@ -60,6 +61,19 @@ def __init__(self, kv_caches: dict) -> None: if not self._segments: raise ValueError("ATOMKVByteCodec: no movable KV tensors registered") + first = self._segments[0] + self.num_blocks: int = int(first.shape[0]) + self._device = first.device + for seg in self._segments: + if seg.device != self._device: + raise ValueError( + "ATOMKVByteCodec: all KV tensors must be on the same device" + ) + if int(seg.shape[0]) != self.num_blocks: + raise ValueError( + "ATOMKVByteCodec: all KV tensors must have the same block count" + ) + # Bytes for one block of each segment (block is dim 0). self._seg_block_bytes: list[int] = [ int(t[0].numel()) * t.element_size() for t in self._segments @@ -71,12 +85,12 @@ def __init__(self, kv_caches: dict) -> None: self._seg_off.append(acc) acc += nb self.bytes_per_block: int = acc - self.num_blocks: int = int(self._segments[0].shape[0]) self.layout = os.environ.get("OFFLOAD_CODEC_LAYOUT", "block").lower() if self.layout not in ("block", "segment", "segment_indexed"): self.layout = "block" self._tls = threading.local() self._native_stitch = None + self._native_split = None if ( self.layout == "segment_indexed" and os.environ.get("OFFLOAD_NATIVE_STITCH", "0").lower() @@ -87,6 +101,7 @@ def __init__(self, kv_caches: dict) -> None: native_stitch.load_extension() self._native_stitch = native_stitch.stitch_chunk_buffers + self._native_split = native_stitch.split_request_buffer except Exception: logger.warning( "ATOMKVByteCodec: native stitch unavailable; using torch stitch", @@ -97,6 +112,10 @@ def __init__(self, kv_caches: dict) -> None: def segments_per_block(self) -> int: return len(self._segments) + @property + def device(self) -> torch.device: + return self._device + def copy_calls_for_blocks(self, nblocks: int) -> int: return int(nblocks) * len(self._segments) @@ -160,6 +179,38 @@ def _segment_bases(self, nblocks: int) -> list[int]: acc += nb * nblocks return bases + def _device_ctx(self): + if self._device.type == "cuda": + return torch.cuda.device(self._device) + return _nullctx() + + def _normalize_block_ids(self, block_ids: list[int]) -> list[int]: + try: + normalized = [operator.index(bid) for bid in block_ids] + except TypeError as exc: + raise ValueError("ATOMKVByteCodec: block_ids must be integers") from exc + if not normalized: + return normalized + min_bid = min(normalized) + max_bid = max(normalized) + if min_bid < 0 or max_bid >= self.num_blocks: + raise ValueError( + "ATOMKVByteCodec: block id out of range " + f"[0, {self.num_blocks}); min={min_bid} max={max_bid}" + ) + return normalized + + def _validate_host_buf(self, host_buf: torch.Tensor, nblocks: int) -> None: + if host_buf.dtype != torch.uint8: + raise TypeError("ATOMKVByteCodec: host_buf must be a uint8 tensor") + required = int(nblocks) * self.bytes_per_block + if int(host_buf.numel()) < required: + raise ValueError( + "ATOMKVByteCodec: host_buf is too small " + f"for {nblocks} blocks; need {required} bytes, " + f"got {int(host_buf.numel())}" + ) + def stitch_chunk_buffers( self, dst: torch.Tensor, @@ -202,6 +253,53 @@ def stitch_chunk_buffers( out=dst[dst_base : dst_base + total_blocks * nb], ) + def split_request_buffer( + self, + src: torch.Tensor, + chunk_buffers: list[torch.Tensor], + chunk_block_counts: list[int], + ) -> None: + """CPU-side inverse of :meth:`stitch_chunk_buffers`. + + ``src`` is one request-level segment-major buffer + ``[seg0 all_blocks | seg1 all_blocks | ...]``. Each destination chunk + receives its own segment-major slice + ``[seg0 chunk_blocks | seg1 chunk_blocks | ...]`` for LMCache storage. + """ + if self._native_split is not None: + self._native_split( + src, + chunk_buffers, + chunk_block_counts, + self._seg_block_bytes, + ) + return + total_blocks = sum(chunk_block_counts) + src_bases = self._segment_bases(total_blocks) + dst_bases_by_chunk = [ + self._segment_bases(nblocks) for nblocks in chunk_block_counts + ] + for seg_idx, (src_base, nb) in enumerate( + zip(src_bases, self._seg_block_bytes) + ): + logical_block_start = 0 + for dst, bases, nblocks in zip( + chunk_buffers, dst_bases_by_chunk, chunk_block_counts + ): + nbytes = nblocks * nb + if nbytes: + dst[ + bases[seg_idx] : bases[seg_idx] + nbytes + ].copy_( + src[ + src_base + + logical_block_start * nb : src_base + + logical_block_start * nb + + nbytes + ] + ) + logical_block_start += nblocks + def _tmp_bytes(self, seg: torch.Tensor, nblocks: int) -> torch.Tensor: elems = int(seg[0].numel()) * seg.element_size() key = (str(seg.device), "uint8", elems, int(nblocks)) @@ -230,45 +328,52 @@ def gpu_to_host( ) -> None: """D2H: gather ``block_ids`` from the paged GPU cache into the flat pinned ``host_buf`` (uint8, length == len(block_ids) * bytes_per_block).""" - ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() - with ctx: - if self.layout == "segment_indexed": - idx = torch.tensor( - block_ids, dtype=torch.long, device=self._segments[0].device - ) - bases = self._segment_bases(len(block_ids)) - for seg, base, nb in zip( - self._segments, bases, self._seg_block_bytes - ): - mat = self._segment_bytes_matrix(seg) - tmp = self._tmp_bytes(seg, len(block_ids)) - torch.index_select(mat, 0, idx, out=tmp) - host_buf[base : base + len(block_ids) * nb].copy_( - tmp.reshape(-1), non_blocking=True + block_ids = self._normalize_block_ids(block_ids) + self._validate_host_buf(host_buf, len(block_ids)) + if not block_ids: + return + with self._device_ctx(): + stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() + with stream_ctx: + if self.layout == "segment_indexed": + idx = torch.tensor( + block_ids, dtype=torch.long, device=self._device ) - return - - if self.layout == "segment": - bases = self._segment_bases(len(block_ids)) - runs = list(self._contiguous_runs(block_ids)) - for seg, base, nb in zip( - self._segments, bases, self._seg_block_bytes - ): - for logical_start, physical_start, run_len in runs: - src = self._blocks_bytes_view(seg, physical_start, run_len) - dst = base + logical_start * nb - host_buf[dst : dst + run_len * nb].copy_( + bases = self._segment_bases(len(block_ids)) + for seg, base, nb in zip( + self._segments, bases, self._seg_block_bytes + ): + mat = self._segment_bytes_matrix(seg) + tmp = self._tmp_bytes(seg, len(block_ids)) + torch.index_select(mat, 0, idx, out=tmp) + host_buf[base : base + len(block_ids) * nb].copy_( + tmp.reshape(-1), non_blocking=True + ) + return + + if self.layout == "segment": + bases = self._segment_bases(len(block_ids)) + runs = list(self._contiguous_runs(block_ids)) + for seg, base, nb in zip( + self._segments, bases, self._seg_block_bytes + ): + for logical_start, physical_start, run_len in runs: + src = self._blocks_bytes_view(seg, physical_start, run_len) + dst = base + logical_start * nb + host_buf[dst : dst + run_len * nb].copy_( + src, non_blocking=True + ) + return + + for i, bid in enumerate(block_ids): + base = i * self.bytes_per_block + for seg, off, nb in zip( + self._segments, self._seg_off, self._seg_block_bytes + ): + src = self._block_bytes_view(seg, bid) + host_buf[base + off : base + off + nb].copy_( src, non_blocking=True ) - return - - for i, bid in enumerate(block_ids): - base = i * self.bytes_per_block - for seg, off, nb in zip( - self._segments, self._seg_off, self._seg_block_bytes - ): - src = self._block_bytes_view(seg, bid) - host_buf[base + off : base + off + nb].copy_(src, non_blocking=True) def host_to_gpu( self, @@ -278,47 +383,55 @@ def host_to_gpu( ) -> None: """H2D: scatter the flat pinned ``host_buf`` back into the paged GPU cache at ``block_ids`` (in-place into the real KV tensors).""" - ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() - with ctx: - if self.layout == "segment_indexed": - idx = torch.tensor( - block_ids, dtype=torch.long, device=self._segments[0].device - ) - bases = self._segment_bases(len(block_ids)) - for seg, base, nb in zip( - self._segments, bases, self._seg_block_bytes - ): - mat = self._segment_bytes_matrix(seg) - tmp = self._tmp_bytes(seg, len(block_ids)) - tmp.copy_( - host_buf[base : base + len(block_ids) * nb].reshape_as(tmp), - non_blocking=True, + block_ids = self._normalize_block_ids(block_ids) + self._validate_host_buf(host_buf, len(block_ids)) + if not block_ids: + return + with self._device_ctx(): + stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() + with stream_ctx: + if self.layout == "segment_indexed": + idx = torch.tensor( + block_ids, dtype=torch.long, device=self._device ) - mat.index_copy_(0, idx, tmp) - return - - if self.layout == "segment": - bases = self._segment_bases(len(block_ids)) - runs = list(self._contiguous_runs(block_ids)) - for seg, base, nb in zip( - self._segments, bases, self._seg_block_bytes - ): - for logical_start, physical_start, run_len in runs: - dst = self._blocks_bytes_view(seg, physical_start, run_len) - src = base + logical_start * nb + bases = self._segment_bases(len(block_ids)) + for seg, base, nb in zip( + self._segments, bases, self._seg_block_bytes + ): + mat = self._segment_bytes_matrix(seg) + tmp = self._tmp_bytes(seg, len(block_ids)) + tmp.copy_( + host_buf[base : base + len(block_ids) * nb].reshape_as(tmp), + non_blocking=True, + ) + mat.index_copy_(0, idx, tmp) + return + + if self.layout == "segment": + bases = self._segment_bases(len(block_ids)) + runs = list(self._contiguous_runs(block_ids)) + for seg, base, nb in zip( + self._segments, bases, self._seg_block_bytes + ): + for logical_start, physical_start, run_len in runs: + dst = self._blocks_bytes_view(seg, physical_start, run_len) + src = base + logical_start * nb + dst.copy_( + host_buf[src : src + run_len * nb], + non_blocking=True, + ) + return + + for i, bid in enumerate(block_ids): + base = i * self.bytes_per_block + for seg, off, nb in zip( + self._segments, self._seg_off, self._seg_block_bytes + ): + dst = self._block_bytes_view(seg, bid) dst.copy_( - host_buf[src : src + run_len * nb], + host_buf[base + off : base + off + nb], non_blocking=True, ) - return - - for i, bid in enumerate(block_ids): - base = i * self.bytes_per_block - for seg, off, nb in zip( - self._segments, self._seg_off, self._seg_block_bytes - ): - dst = self._block_bytes_view(seg, bid) - dst.copy_(host_buf[base + off : base + off + nb], non_blocking=True) class _nullctx: diff --git a/atom/kv_transfer/offload/native_stitch.cpp b/atom/kv_transfer/offload/native_stitch.cpp index b81022e354..6e76ca425c 100644 --- a/atom/kv_transfer/offload/native_stitch.cpp +++ b/atom/kv_transfer/offload/native_stitch.cpp @@ -76,6 +76,75 @@ void stitch_chunk_buffers( }); } +void split_request_buffer( + torch::Tensor src, + std::vector chunk_buffers, + std::vector chunk_block_counts, + std::vector seg_block_bytes) { + TORCH_CHECK(src.device().is_cpu(), "src must be a CPU tensor"); + TORCH_CHECK(src.dtype() == torch::kUInt8, "src must be uint8"); + TORCH_CHECK(src.is_contiguous(), "src must be contiguous"); + TORCH_CHECK( + chunk_buffers.size() == chunk_block_counts.size(), + "chunk_buffers and chunk_block_counts size mismatch"); + + const int64_t nchunks = static_cast(chunk_buffers.size()); + const int64_t nsegs = static_cast(seg_block_bytes.size()); + int64_t total_blocks = 0; + for (int64_t nblocks : chunk_block_counts) { + TORCH_CHECK(nblocks >= 0, "chunk block count must be non-negative"); + total_blocks += nblocks; + } + + std::vector src_bases(nsegs); + int64_t acc = 0; + for (int64_t seg = 0; seg < nsegs; ++seg) { + const int64_t nb = seg_block_bytes[seg]; + TORCH_CHECK(nb >= 0, "segment byte count must be non-negative"); + src_bases[seg] = acc; + acc += nb * total_blocks; + } + TORCH_CHECK(src.numel() >= acc, "src is smaller than split input"); + + std::vector dst_ptrs(nchunks); + std::vector dst_offsets(nchunks * nsegs); + for (int64_t c = 0; c < nchunks; ++c) { + auto& dst = chunk_buffers[c]; + TORCH_CHECK(dst.device().is_cpu(), "chunk buffer must be a CPU tensor"); + TORCH_CHECK(dst.dtype() == torch::kUInt8, "chunk buffer must be uint8"); + TORCH_CHECK(dst.is_contiguous(), "chunk buffer must be contiguous"); + dst_ptrs[c] = dst.data_ptr(); + + int64_t dst_acc = 0; + const int64_t nblocks = chunk_block_counts[c]; + for (int64_t seg = 0; seg < nsegs; ++seg) { + dst_offsets[c * nsegs + seg] = dst_acc; + dst_acc += seg_block_bytes[seg] * nblocks; + } + TORCH_CHECK(dst.numel() >= dst_acc, "chunk buffer is smaller than expected"); + } + + const auto* src_ptr = src.data_ptr(); + at::parallel_for(0, nsegs, 1, [&](int64_t begin, int64_t end) { + for (int64_t seg = begin; seg < end; ++seg) { + const int64_t nb = seg_block_bytes[seg]; + int64_t logical_block_start = 0; + for (int64_t c = 0; c < nchunks; ++c) { + const int64_t nblocks = chunk_block_counts[c]; + const int64_t nbytes = nblocks * nb; + if (nbytes > 0) { + std::memcpy( + dst_ptrs[c] + dst_offsets[c * nsegs + seg], + src_ptr + src_bases[seg] + logical_block_start * nb, + static_cast(nbytes)); + } + logical_block_start += nblocks; + } + } + }); +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("stitch_chunk_buffers", &stitch_chunk_buffers); + m.def("split_request_buffer", &split_request_buffer); } diff --git a/atom/kv_transfer/offload/native_stitch.py b/atom/kv_transfer/offload/native_stitch.py index c9c2ac5e70..5a75c20522 100644 --- a/atom/kv_transfer/offload/native_stitch.py +++ b/atom/kv_transfer/offload/native_stitch.py @@ -42,3 +42,12 @@ def stitch_chunk_buffers(dst, chunk_buffers, chunk_block_counts, seg_block_bytes [int(x) for x in chunk_block_counts], [int(x) for x in seg_block_bytes], ) + + +def split_request_buffer(src, chunk_buffers, chunk_block_counts, seg_block_bytes) -> None: + _load_ext().split_request_buffer( + src, + chunk_buffers, + [int(x) for x in chunk_block_counts], + [int(x) for x in seg_block_bytes], + ) diff --git a/atom/kv_transfer/offload/trace.py b/atom/kv_transfer/offload/trace.py new file mode 100644 index 0000000000..0d0d16bed5 --- /dev/null +++ b/atom/kv_transfer/offload/trace.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +from __future__ import annotations + +import logging +import os +import time + +logger = logging.getLogger("atom") +_START = time.perf_counter() + + +def offload_trace_enabled() -> bool: + return os.environ.get("OFFLOAD_TRACE_E2E", "0").lower() not in ( + "0", + "false", + "no", + "off", + ) + + +def offload_trace(event: str, **fields) -> None: + if not offload_trace_enabled(): + return + now = time.perf_counter() + parts = [f"{key}={value}" for key, value in fields.items()] + logger.info( + "[OFFLOAD-TRACE] t=%.6f dt_ms=%.3f event=%s %s", + now, + (now - _START) * 1000, + event, + " ".join(parts), + ) diff --git a/atom/model_engine/engine_core.py b/atom/model_engine/engine_core.py index 4fa412a40d..2103965ece 100644 --- a/atom/model_engine/engine_core.py +++ b/atom/model_engine/engine_core.py @@ -202,6 +202,7 @@ def _process_engine_step(self): self.output_queue.put_nowait(rejected) if result is None: + self._dispatch_idle_offload_work() if self.kv_transfer_enabled: kvoutput = self.runner_mgr.call_func_with_aggregation( "async_proc_aggregation" @@ -212,6 +213,7 @@ def _process_engine_step(self): if scheduled_batch is None: logger.debug("%s: No sequences to schedule, skipping forward", self.label) + self._dispatch_idle_offload_work() if self.kv_transfer_enabled: kvoutput = self.runner_mgr.call_func_with_aggregation( "async_proc_aggregation" @@ -268,6 +270,17 @@ def _process_engine_step(self): self.output_queue.put_nowait(finished_seqs) return True + def _dispatch_idle_offload_work(self) -> None: + if not self.kv_transfer_enabled: + return + connector = getattr(self.scheduler, "kv_connector", None) + if connector is None or not getattr(connector, "is_offload", False): + return + meta = connector.build_connector_meta() + if meta is None or not getattr(meta, "requests", None): + return + self.runner_mgr.call_func("process_kvconnector_output", meta) + def pull_and_process_input_queue(self): recv_reqs = [] while not self.input_queue.empty(): diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index ebd2823712..9d26aba5ca 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -50,6 +50,7 @@ set_forward_context, set_kv_cache_data, ) +from atom.kv_transfer.offload.trace import offload_trace from atom.utils.selector import get_attn_backend from torch.profiler import record_function @@ -2009,6 +2010,15 @@ def postprocess( @torch.inference_mode() def forward(self, batch: ScheduledBatch) -> ScheduledBatchOutput: + t_forward0 = time.perf_counter() + offload_trace( + "runner_forward_start", + reqs=batch.req_ids, + prefill=batch.total_seqs_num_prefill, + decode=batch.total_seqs_num_decode, + tokens=batch.total_tokens_num, + cached=batch.num_cached_tokens, + ) ( input_ids, temperatures, @@ -2029,6 +2039,15 @@ def forward(self, batch: ScheduledBatch) -> ScheduledBatchOutput: needs_independent_noise=needs_independent_noise, ) reset_forward_context() + offload_trace( + "runner_forward_done", + reqs=batch.req_ids, + out_reqs=fwd_output.req_ids, + prefill=batch.total_seqs_num_prefill, + decode=batch.total_seqs_num_decode, + tokens=batch.total_tokens_num, + total_ms=f"{(time.perf_counter() - t_forward0) * 1000:.2f}", + ) return fwd_output @torch.inference_mode() @@ -2037,6 +2056,12 @@ def process_kvconnector_output(self, connector_meta_output): if connector_meta_output is not None: connector = get_kvconnector() if connector is not None: + reqs = getattr(connector_meta_output, "requests", None) + offload_trace( + "runner_connector_dispatch", + requests=[r.req_id for r in reqs] if reqs is not None else None, + nrequests=len(reqs) if reqs is not None else None, + ) connector.start_load_kv(connector_meta_output) @torch.inference_mode() @@ -2048,6 +2073,19 @@ def async_proc_aggregation(self) -> KVConnectorOutput: finished = connector.get_finished() if isinstance(finished, KVConnectorOutput): + if ( + finished.finished_sending + or finished.finished_recving + or finished.failed_recving + or finished.finished_saving + ): + offload_trace( + "runner_connector_finished", + sending=sorted(finished.finished_sending or ()), + recving=sorted(finished.finished_recving or ()), + failed=sorted(finished.failed_recving or ()), + saving=sorted(finished.finished_saving or ()), + ) return finished done_sending, done_recving = finished diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index ac72b80953..991a108dd6 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -29,6 +29,7 @@ from atom.model_engine.block_manager import BlockManager from atom.model_engine.request import RequestOutput from atom.model_engine.sequence import Sequence, SequenceStatus, SequenceType +from atom.kv_transfer.offload.trace import offload_trace logger = logging.getLogger("atom") @@ -486,6 +487,8 @@ def _can_admit_head_prefill(self) -> bool: entries) and check `can_allocate` + token-budget, mirroring the same checks the admission while-loop runs below. """ + if self._partial_prefill_count > 0: + return True if not self.waiting: return False for i, seq in enumerate(self.waiting): @@ -496,7 +499,10 @@ def _can_admit_head_prefill(self) -> bool: if seq.status == SequenceStatus.WAITING_FOR_REMOTE_KVS: continue num_new_tokens = seq.num_tokens - seq.num_cached_tokens - if num_new_tokens > self.max_num_batched_tokens: + if ( + not self.enable_chunked_prefill + and num_new_tokens > self.max_num_batched_tokens + ): continue if self.block_manager.can_allocate(seq) < 0: return False # KV-pressured: definitely cannot prefill @@ -709,6 +715,12 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: waiting_remote_to_waiting_ready = False if seq.status == SequenceStatus.WAITING_FOR_REMOTE_KVS: if self._pop_req_id(self.failed_recving_kv_req_ids, seq.id): + offload_trace( + "scheduler_load_failed_wake", + req=seq.id, + cached=seq.num_cached_tokens, + prompt=seq.num_prompt_tokens, + ) if self.kv_connector is not None and hasattr( self.kv_connector, "load_failed" ): @@ -745,6 +757,13 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: "[OFFLOAD-WAKE] seq %s: loaded=%s prev_cached=%d num_tokens=%d", seq.id, loaded, seq.num_cached_tokens, seq.num_tokens, ) + offload_trace( + "scheduler_offload_wake", + req=seq.id, + loaded=loaded, + prev_cached=seq.num_cached_tokens, + prompt=seq.num_prompt_tokens, + ) if loaded is not None and loaded > seq.num_cached_tokens: seq.num_cached_tokens = loaded seq.offload_loaded = True @@ -822,13 +841,28 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: self.waiting.appendleft(seq) break chunk = num_new_tokens - + t_alloc0 = time.perf_counter() self.block_manager.allocate(seq, num_cached_blocks) + offload_trace( + "scheduler_alloc_done", + req=seq.id, + cached_blocks=num_cached_blocks, + cached_tokens=seq.num_cached_tokens, + blocks=len(seq.block_table), + alloc_ms=f"{(time.perf_counter() - t_alloc0) * 1000:.2f}", + ) if self.kv_connector is not None: self.kv_connector.update_state_after_alloc(seq) if need_to_remove_to_load_kv_async_queue: + offload_trace( + "scheduler_park_for_load", + req=seq.id, + cached=seq.num_cached_tokens, + prompt=seq.num_prompt_tokens, + blocks=len(seq.block_table), + ) skipped_waiting_requests.append(seq) seq.status = SequenceStatus.WAITING_FOR_REMOTE_KVS continue @@ -837,15 +871,24 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: f"chunk must be positive: {chunk=}, " f"{num_new_tokens=}, {budget_remaining=}" ) + num_seqs_prefill += 1 if self.cache_stats: self.cache_stats.update(seq.num_cached_tokens, seq.num_tokens) num_batched_tokens += chunk - num_seqs_prefill += 1 seq.status = SequenceStatus.RUNNING seq.type = SequenceType.PREFILL self.running.append(seq) scheduled_seqs[seq.id] = seq num_scheduled_tokens.append(chunk) + offload_trace( + "scheduler_prefill_scheduled", + req=seq.id, + new_tokens=chunk, + cached=seq.num_cached_tokens, + prompt=seq.num_prompt_tokens, + offload_loaded=getattr(seq, "offload_loaded", False), + load_failed=getattr(seq, "offload_load_failed", False), + ) if skipped_waiting_requests: logger.debug( @@ -894,10 +937,14 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: num_new_tokens = self.mtp_k + 1 remote_kv_blocks: set[int] = set() remote_kv_seq_blocks: dict[int, list[int]] = {} + skipped_partial_prefills: list[Sequence] = [] while self.running and num_seqs_decode < self.max_num_seqs: if num_decode_tokens + tokens_per_decode_seq > self.max_num_batched_tokens: break seq = self.running.popleft() + if seq.is_partial_prefill: + skipped_partial_prefills.append(seq) + continue while not self.block_manager.can_append(seq, num_new_tokens): if self.running: self.preempt(self.running.pop()) @@ -943,6 +990,8 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: if scheduled_seqs: self.running.extendleft(reversed(scheduled_seqs.values())) + if skipped_partial_prefills: + self.running.extendleft(reversed(skipped_partial_prefills)) connector_meta_output = None if self.kv_connector is not None: @@ -1076,10 +1125,11 @@ def postprocess( # later in this loop, so they're not part of the prompt hash # chain — leaving them in would mint a stale partial-block hash. if not seq.prefix_hashes_published: - _num_new = seq.num_tokens - seq.num_cached_tokens - if need_placeholder: - _num_new -= num_placeholder - self.block_manager.hash_blocks(seq, _num_new) + if batch is None: + _num_new = seq.num_tokens - seq.num_cached_tokens + if need_placeholder: + _num_new -= num_placeholder + self.block_manager.hash_blocks(seq, max(0, _num_new)) seq.prefix_hashes_published = True token_ids = prev_token_ids[idx] num_new_token = len(token_ids) @@ -1255,6 +1305,8 @@ def postprocess( seq.is_partial_prefill = False self._partial_prefill_count -= 1 if self.kv_connector is not None: + if hasattr(self.kv_connector, "request_finished"): + self.kv_connector.request_finished(seq) if not self.kv_connector.is_producer: if hasattr(self.kv_connector, "should_defer_free") and ( self.kv_connector.should_defer_free(seq) @@ -1313,6 +1365,12 @@ def _update_waiting_for_remote_kv(self, seq: Sequence) -> bool: return False logger.debug("KV transfer finished for seq %s, ready for scheduling.", seq.id) + offload_trace( + "scheduler_recv_ready", + req=seq.id, + cached=getattr(seq, "num_cached_tokens", None), + prompt=getattr(seq, "num_prompt_tokens", None), + ) return True def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): @@ -1339,6 +1397,7 @@ def _pop_deferred(req_id): not self.kv_connector.is_producer ), "Only consumer should update recving KV status" logger.debug("Finished recving KV transfer for request %s", req_id) + offload_trace("scheduler_finished_recving", req=req_id) self.finished_recving_kv_req_ids.append(req_id) for req_id in kv_connector_output.failed_recving or (): @@ -1346,6 +1405,7 @@ def _pop_deferred(req_id): not self.kv_connector.is_producer ), "Only consumer should update failed KV recv status" logger.warning("KV receive failed for request %s; falling back to prefill.", req_id) + offload_trace("scheduler_failed_recving", req=req_id) self.failed_recving_kv_req_ids.append(req_id) for req_id in kv_connector_output.finished_sending or (): @@ -1360,6 +1420,7 @@ def _pop_deferred(req_id): for req_id in kv_connector_output.finished_saving or (): if hasattr(self.kv_connector, "save_finished"): self.kv_connector.save_finished(req_id) + offload_trace("scheduler_finished_saving", req=req_id) seq = self.deferred_free_blocks.get(req_id) if seq is None: try: @@ -1370,7 +1431,11 @@ def _pop_deferred(req_id): hasattr(self.kv_connector, "should_defer_free") and self.kv_connector.should_defer_free(seq) ): - self.block_manager.deallocate(_pop_deferred(req_id)) + seq = _pop_deferred(req_id) + if seq is not None and hasattr(self.kv_connector, "request_finished"): + self.kv_connector.request_finished(seq) + if seq is not None: + self.block_manager.deallocate(seq) def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" @@ -1392,7 +1457,7 @@ def has_requests(self) -> bool: def get_next_batch_info(self) -> tuple[bool, int, int]: # Check for partial prefills in running (chunked prefill resume) for seq in self.running: - if seq.num_cached_tokens < seq.num_prompt_tokens: + if seq.is_partial_prefill: remaining = seq.num_prompt_tokens - seq.num_cached_tokens chunk = min(remaining, self.max_num_batched_tokens) return (True, chunk, 1) @@ -1409,6 +1474,8 @@ def get_next_batch_info(self) -> tuple[bool, int, int]: total_tokens = 0 for seq in eligible_waiting: tokens = seq.num_tokens - seq.num_cached_tokens + if self.enable_chunked_prefill: + tokens = min(tokens, self.max_num_batched_tokens - total_tokens) if total_tokens + tokens > self.max_num_batched_tokens: break if num_reqs >= self.max_num_seqs: diff --git a/tests/conftest.py b/tests/conftest.py index 326335cb9f..3fd4e233f1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import importlib import importlib.util +import importlib.machinery import sys import os import types @@ -59,6 +60,16 @@ class _StubParallelConfig: _atom_config.ParallelConfig = _StubParallelConfig sys.modules["atom.config"] = _atom_config +# ── 3b. Stub forward_context; Scheduler only needs get_kvconnector in tests ── + +_forward_context = types.ModuleType("atom.utils.forward_context") +_forward_context.__package__ = "atom.utils" +_forward_context.__spec__ = importlib.machinery.ModuleSpec( + "atom.utils.forward_context", loader=None +) +_forward_context.get_kvconnector = lambda *args, **kwargs: None +sys.modules["atom.utils.forward_context"] = _forward_context + # ── 4. Stub zmq / zmq.asyncio if not installed ──────────────────────────── if importlib.util.find_spec("zmq") is None: @@ -116,6 +127,7 @@ def __init__(self, **overrides): kv_cache_block_size=4, num_kvcache_blocks=10, enable_prefix_caching=False, + enable_chunked_prefill=True, max_num_seqs=4, max_num_batched_tokens=64, max_model_len=64, @@ -124,7 +136,6 @@ def __init__(self, **overrides): stop_token_ids=[], scheduler_delay_factor=0.0, speculative_config=None, - enable_chunked_prefill=False, ) defaults.update(overrides) for k, v in defaults.items(): diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index 558675eca9..85b0507dce 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -20,6 +20,7 @@ LMCacheOffloadConnector, LMCacheOffloadConnectorScheduler, ) +from atom.kv_transfer.offload import connector as offload_connector_mod from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec from atom.model_engine.scheduler import Scheduler @@ -27,10 +28,14 @@ class _LookupClient: def __init__(self, hit: int) -> None: self.hit = hit + self.cleared = [] def lookup(self, token_ids, lookup_id): return self.hit + def clear_lookup_status(self, lookup_id): + self.cleared.append(lookup_id) + def _scheduler() -> LMCacheOffloadConnectorScheduler: sched = LMCacheOffloadConnectorScheduler.__new__(LMCacheOffloadConnectorScheduler) @@ -44,6 +49,8 @@ def _scheduler() -> LMCacheOffloadConnectorScheduler: sched._save_tracker = {} sched._save_inflight = set() sched._lookup_in_step = [] + sched._lock = threading.Lock() + sched._done_load = set() return sched @@ -145,6 +152,104 @@ def test_segment_indexed_stitches_chunk_buffers(monkeypatch): assert torch.equal(stitched, direct) + split_buffers = [torch.empty_like(buf) for buf in chunk_buffers] + codec.split_request_buffer(stitched, split_buffers, [len(b) for b in chunks]) + for actual, expected in zip(split_buffers, chunk_buffers): + assert torch.equal(actual, expected) + + +@pytest.mark.parametrize("layout", ["block", "segment", "segment_indexed"]) +@pytest.mark.parametrize("method_name", ["gpu_to_host", "host_to_gpu"]) +def test_codec_rejects_invalid_block_ids_before_copy(monkeypatch, layout, method_name): + import torch + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", layout) + kv_caches = { + "l0": SimpleNamespace( + k_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), + v_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), + k_scale=None, + v_scale=None, + ), + } + codec = ATOMKVByteCodec(kv_caches) + host = torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) + method = getattr(codec, method_name) + + with pytest.raises(ValueError, match="block id out of range"): + method(host, [0, 4]) + + with pytest.raises(ValueError, match="block id out of range"): + method(host, [-1]) + + +def test_codec_rejects_short_host_buffer(monkeypatch): + import torch + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") + kv_caches = { + "l0": SimpleNamespace( + k_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), + v_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), + k_scale=None, + v_scale=None, + ), + } + codec = ATOMKVByteCodec(kv_caches) + host = torch.empty(codec.bytes_per_block - 1, dtype=torch.uint8) + + with pytest.raises(ValueError, match="host_buf is too small"): + codec.gpu_to_host(host, [0]) + + +def test_copy_stream_is_cached_per_codec_device(monkeypatch): + import torch + if not hasattr(torch, "device") or not hasattr(torch, "cuda"): + pytest.skip("torch cuda API is unavailable") + + conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) + conn._tls = threading.local() + conn._codec = SimpleNamespace(device=torch.device("cuda:1")) + active_devices = [] + created_on = [] + + class _FakeDeviceCtx: + def __init__(self, device) -> None: + self.device = str(device) + + def __enter__(self): + active_devices.append(self.device) + return None + + def __exit__(self, *args): + active_devices.pop() + return False + + class _FakeStream: + def __init__(self) -> None: + created_on.append(active_devices[-1] if active_devices else "default") + + monkeypatch.setattr( + offload_connector_mod.torch.cuda, + "device", + lambda device: _FakeDeviceCtx(device), + ) + monkeypatch.setattr(offload_connector_mod.torch.cuda, "Stream", _FakeStream) + + rank1_stream = conn._stream() + assert conn._stream() is rank1_stream + assert created_on == ["cuda:1"] + + conn._codec = SimpleNamespace(device=torch.device("cuda:0")) + rank0_stream = conn._stream() + + assert rank0_stream is not rank1_stream + assert created_on == ["cuda:1", "cuda:0"] + def test_full_prompt_hit_is_clamped_before_load_spec(): sched = _scheduler() @@ -163,6 +268,121 @@ def test_full_prompt_hit_is_clamped_before_load_spec(): assert sched._load_specs[str(seq.id)].lmcache_cached_tokens == 7 +def test_load_is_skipped_if_hbm_satisfies_after_allocation(): + sched = _scheduler() + lookup = _LookupClient(hit=8) + sched._lookup_client = lookup + seq = SimpleNamespace( + id=321, + num_prompt_tokens=12, + token_ids=list(range(12)), + num_cached_tokens=0, + block_table=[1, 2, 3], + ) + + need, should_park = sched.get_num_new_matched_tokens(seq) + assert need == 8 + assert should_park is True + + # Prefix-cache allocation can discover a larger HBM hit than the lookup-time + # snapshot. In that case the scheduler still emits a no-op load so the + # normal worker aggregation path can wake the parked seq. + seq.num_cached_tokens = 8 + sched.update_state_after_alloc(seq) + meta = sched.build_connector_meta() + + assert len(meta.requests) == 1 + req = meta.requests[0] + assert req.req_id == 321 + assert req.token_ids == list(range(8)) + assert req.block_ids == [1, 2, 3] + assert req.load_spec.hbm_cached_tokens == 8 + assert req.load_spec.lmcache_cached_tokens == 8 + assert seq.offload_loaded_tokens == 8 + assert lookup.cleared == [] + + +def test_load_is_skipped_if_hbm_floor_is_not_chunk_aligned(): + sched = _scheduler() + lookup = _LookupClient(hit=12) + sched._lookup_client = lookup + seq = SimpleNamespace( + id=654, + num_prompt_tokens=16, + token_ids=list(range(16)), + num_cached_tokens=0, + block_table=[1, 2, 3, 4], + ) + + need, should_park = sched.get_num_new_matched_tokens(seq) + assert need == 12 + assert should_park is True + + # HBM prefix cache can return block-size granularity, while LMCache chunks + # are larger. Loading from a non-chunk boundary would either overlap shared + # HBM blocks or leave a gap, so the scheduler wakes the seq with a no-op + # load and lets suffix prefill continue from the HBM floor. + seq.num_cached_tokens = 6 + sched.update_state_after_alloc(seq) + meta = sched.build_connector_meta() + + assert len(meta.requests) == 1 + req = meta.requests[0] + assert req.req_id == 654 + assert req.token_ids == list(range(6)) + assert req.block_ids == [1, 2, 3, 4] + assert req.load_spec.hbm_cached_tokens == 6 + assert req.load_spec.lmcache_cached_tokens == 6 + assert seq.offload_loaded_tokens == 6 + + +def test_worker_completes_noop_load_when_hbm_satisfies(): + conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) + conn._lock = threading.Lock() + conn._done_load = set() + conn._failed_load = set() + conn._done_save = set() + conn._engine = SimpleNamespace(unpinned=[]) + conn._engine.lookup_unpin = lambda ids: conn._engine.unpinned.extend(ids) + + req = SimpleNamespace( + req_id=321, + token_ids=list(range(8)), + block_ids=[1, 2, 3], + load_spec=SimpleNamespace(hbm_cached_tokens=8, lmcache_cached_tokens=8), + ) + + conn._do_load_req(req) + + assert conn._done_load == {321} + assert conn._failed_load == set() + assert conn._engine.unpinned == ["321"] + + +def test_worker_reports_unaligned_hbm_load_as_failed_without_exception(): + conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) + conn._lock = threading.Lock() + conn._done_load = set() + conn._failed_load = set() + conn._done_save = set() + conn.chunk_size = 4 + conn._engine = SimpleNamespace(unpinned=[]) + conn._engine.lookup_unpin = lambda ids: conn._engine.unpinned.extend(ids) + + req = SimpleNamespace( + req_id=654, + token_ids=list(range(12)), + block_ids=[1, 2, 3], + load_spec=SimpleNamespace(hbm_cached_tokens=6, lmcache_cached_tokens=12), + ) + + conn._do_load_req(req) + + assert conn._done_load == set() + assert conn._failed_load == {654} + assert conn._engine.unpinned == ["654"] + + def test_load_exception_is_reported_as_failed_recving(): conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) conn._lock = threading.Lock() @@ -216,6 +436,7 @@ def test_save_inflight_defers_free_until_save_finishes(): token_ids=list(range(8)), block_table=[3, 4], num_prompt_tokens=8, + num_cached_tokens=8, prefix_hashes_published=True, ) sched._save_tracker[str(seq.id)] = [seq, 0] @@ -231,6 +452,40 @@ def test_save_inflight_defers_free_until_save_finishes(): assert sched.should_defer_free(seq) is False +def test_chunked_prefill_save_uses_computed_frontier_and_serializes_inflight(): + sched = _scheduler() + seq = SimpleNamespace( + id=10, + token_ids=list(range(12)), + block_table=[3, 4, 5], + num_prompt_tokens=12, + num_cached_tokens=8, + is_partial_prefill=True, + ) + sched._save_tracker[str(seq.id)] = [seq, 0] + + meta1 = sched.build_connector_meta() + + assert len(meta1.requests) == 1 + assert len(meta1.requests[0].token_ids) == 8 + assert meta1.requests[0].save_spec.skip_leading_tokens == 0 + assert meta1.requests[0].is_last_prefill is False + assert sched.should_defer_free(seq) is True + + seq.num_cached_tokens = 12 + seq.is_partial_prefill = False + meta2 = sched.build_connector_meta() + assert len(meta2.requests) == 0 + + sched.save_finished(seq.id) + meta3 = sched.build_connector_meta() + + assert len(meta3.requests) == 1 + assert len(meta3.requests[0].token_ids) == 12 + assert meta3.requests[0].save_spec.skip_leading_tokens == 8 + assert meta3.requests[0].is_last_prefill is True + + def test_finished_saving_releases_deferred_free_with_string_req_id(): class _BlockManager: def __init__(self) -> None: diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 5933e52443..e53d84b935 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -98,6 +98,42 @@ def test_prefill_respects_max_batched_tokens(self, seq_factory): assert batch.total_tokens_num_prefill == 6 assert list(batch.num_scheduled_tokens) == [4, 2] + def test_chunked_prefill_splits_prompt_across_steps(self, seq_factory): + sched = Scheduler( + MockConfig( + max_num_batched_tokens=6, + num_kvcache_blocks=100, + kv_cache_block_size=4, + enable_chunked_prefill=True, + ) + ) + seq = seq_factory(list(range(10))) + sched.add(seq) + + batch1, _ = sched.schedule() + assert batch1.total_tokens_num_prefill == 6 + assert list(batch1.scheduled_tokens) == list(range(6)) + assert list(batch1.num_cached_tokens) == [0] + + sched.postprocess( + list(sched.running), + ScheduledBatchOutput( + req_ids=[], + token_ids=[], + num_rejected=None, + num_bonus=None, + draft_token_ids=None, + ), + batch=batch1, + ) + assert seq.is_partial_prefill is True + assert seq.num_cached_tokens == 6 + + batch2, _ = sched.schedule() + assert batch2.total_tokens_num_prefill == 4 + assert list(batch2.scheduled_tokens) == list(range(6, 10)) + assert list(batch2.num_cached_tokens) == [6] + def test_prefill_respects_block_availability(self, seq_factory): sched = Scheduler(MockConfig(num_kvcache_blocks=1, kv_cache_block_size=4)) sched.add(seq_factory([1, 2, 3, 4])) # 1 block From 77dbd02d85fef54063fa86e8cfecf00a39a2a703 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Mon, 1 Jun 2026 03:52:43 -0500 Subject: [PATCH 04/27] WIP lmcache partial reload and HOL wake --- atom/kv_transfer/offload/connector.py | 426 ++++++++++++++++------ atom/kv_transfer/offload/gpu_connector.py | 80 ++++ atom/model_engine/scheduler.py | 58 +++ tests/test_lmcache_offload_connector.py | 260 +++++++++++-- tests/test_scheduler.py | 47 +++ 5 files changed, 745 insertions(+), 126 deletions(-) diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index 658c320fef..5f97a073c0 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -345,6 +345,20 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: toks=len(toks), blocks=len(req.block_ids), ) + + def fail_load(status: str, **fields) -> None: + self._lookup_unpin(req.req_id) + with self._lock: + self._failed_load.add(req.req_id) + offload_trace( + "worker_load_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status=status, + total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", + **fields, + ) + if int(ls.lmcache_cached_tokens) <= hbm: self._lookup_unpin(req.req_id) with self._lock: @@ -358,74 +372,131 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: ) return chunk_size = int(self.chunk_size or 256) - if hbm % chunk_size != 0: + block_size = int(getattr(self, "block_size", 1) or 1) + if hbm % block_size != 0: logger.warning( - "LMCache offload: HBM prefix is not chunk-aligned req=%s " - "hbm=%d chunk=%d; re-prefill", + "LMCache offload: HBM prefix is not block-aligned req=%s " + "hbm=%d block=%d; re-prefill", req.req_id, hbm, - chunk_size, + block_size, ) - self._lookup_unpin(req.req_id) - with self._lock: - self._failed_load.add(req.req_id) - offload_trace( - "worker_load_done", - rank=getattr(self, "_rank", "?"), - req=req.req_id, - status="unaligned_hbm", + fail_load( + "unaligned_hbm", hbm=hbm, chunk=chunk_size, - total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", + block=block_size, + ) + return + if chunk_size % block_size != 0: + logger.warning( + "LMCache offload: chunk size is not block-aligned req=%s " + "chunk=%d block=%d; re-prefill", + req.req_id, + chunk_size, + block_size, + ) + fail_load( + "unaligned_chunk_block", + chunk=chunk_size, + block=block_size, ) return + hbm_floor = (hbm // chunk_size) * chunk_size stream = self._stream() mask = torch.ones(len(toks), dtype=torch.bool) - mask[:hbm] = False + mask[:hbm_floor] = False t0 = time.perf_counter() - chunks = list(self._tdb.process_tokens(torch.tensor(toks), mask=mask)) + chunks = [ + (int(s), int(e), key) + for (s, e, key) in self._tdb.process_tokens(torch.tensor(toks), mask=mask) + if int(e) > hbm + ] process_ms = (time.perf_counter() - t0) * 1000 - logger.debug("offload _do_load req=%s hbm=%d lmc=%d chunks=%d", - req.req_id, hbm, ls.lmcache_cached_tokens, len(chunks)) + logger.debug( + "offload _do_load req=%s hbm=%d floor=%d lmc=%d chunks=%d", + req.req_id, + hbm, + hbm_floor, + ls.lmcache_cached_tokens, + len(chunks), + ) - # All-or-nothing above the HBM prefix: a partial load would let attention - # read uninitialized blocks, and a chunk that overlaps an HBM-cache hit - # could overwrite shared prefix-cache blocks. In either case the seq - # wakes and re-prefills from its HBM floor. + # All-or-nothing above the HBM prefix. When HBM lands inside an LMCache + # chunk, retrieve the overlapping chunk but copy only the block-aligned + # tail; shared HBM-hit blocks are owned by prefix cache and must not be + # written by this request's reload. if not chunks: logger.warning("LMCache offload: no loadable chunks req=%s; re-prefill", req.req_id) - self._lookup_unpin(req.req_id) - with self._lock: - self._failed_load.add(req.req_id) - offload_trace( - "worker_load_done", - rank=getattr(self, "_rank", "?"), - req=req.req_id, - status="no_chunks", - total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", - ) + fail_load("no_chunks") return - for (s, _e, _key) in chunks: - if s < hbm: + copy_spans = [] + skipped_shared_blocks = 0 + partial_first_chunk = 0 + for (s, e, key) in chunks: + copy_start = max(s, hbm) + copy_end = e + if copy_end <= copy_start: + continue + if ( + s % block_size != 0 + or e % block_size != 0 + or copy_start % block_size != 0 + or copy_end % block_size != 0 + ): logger.warning( - "LMCache offload: chunk overlaps HBM prefix req=%s hbm=%d " - "chunk_start=%d; re-prefill", - req.req_id, hbm, s, + "LMCache offload: load span is not block-aligned req=%s " + "chunk=[%d,%d) copy=[%d,%d) block=%d; re-prefill", + req.req_id, + s, + e, + copy_start, + copy_end, + block_size, ) - self._lookup_unpin(req.req_id) - with self._lock: - self._failed_load.add(req.req_id) - offload_trace( - "worker_load_done", - rank=getattr(self, "_rank", "?"), - req=req.req_id, - status="overlap_hbm", + fail_load( + "unaligned_copy_span", hbm=hbm, chunk_start=s, - total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", + chunk_end=e, + copy_start=copy_start, + copy_end=copy_end, + block=block_size, + ) + return + chunk_block_count = (e - s) // block_size + src_skip_blocks = (copy_start - s) // block_size + bids = self._block_ids(req, copy_start, copy_end) + expected_blocks = (copy_end - copy_start) // block_size + if len(bids) != expected_blocks: + logger.warning( + "LMCache offload: block table too short req=%s " + "copy=[%d,%d) expected_blocks=%d got=%d; re-prefill", + req.req_id, + copy_start, + copy_end, + expected_blocks, + len(bids), + ) + fail_load( + "bad_block_table", + hbm=hbm, + copy_start=copy_start, + copy_end=copy_end, + expected_blocks=expected_blocks, + got_blocks=len(bids), ) return + if src_skip_blocks: + partial_first_chunk = 1 + skipped_shared_blocks += src_skip_blocks + copy_spans.append((s, e, key, bids, src_skip_blocks, chunk_block_count)) + if not copy_spans: + logger.warning("LMCache offload: no copy spans req=%s; re-prefill", + req.req_id) + fail_load("no_copy_spans") + return contains_ms = 0.0 loaded_objs = [] get_ms = 0.0 @@ -436,9 +507,7 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: nblocks = 0 nbytes = 0 copy_calls = 0 - chunk_bids: list[list[int]] = [ - self._block_ids(req, s, e) for (s, e, _key) in chunks - ] + chunk_bids: list[list[int]] = [span[3] for span in copy_spans] all_bids = [bid for bids in chunk_bids for bid in bids] nblocks = len(all_bids) nbytes = nblocks * self._codec.bytes_per_block @@ -476,6 +545,8 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: chunks=len(chunks), blocks=nblocks, bytes_gib=f"{nbytes / 1024**3:.3f}", + partial_first_chunk=partial_first_chunk, + skipped_shared_blocks=skipped_shared_blocks, stitch_ms=f"{stitch_ms:.2f}", h2d_submit_ms=f"{h2d_submit_ms:.2f}", sync_ms=f"{sync_ms:.2f}", @@ -483,8 +554,10 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: ) if self._profile_enabled(): logger.info( - "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " - "chunks=%d blocks=%d bytes=%.3fGiB copy_calls=%d " + "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d " + "hbm_floor=%d lmc=%d chunks=%d blocks=%d " + "bytes=%.3fGiB copy_calls=%d partial_first_chunk=%d " + "skipped_shared_blocks=%d " "layout=%s fastpath=request process_ms=%.2f " "contains_ms=%.2f get_ms=%.2f host_alloc_ms=%.2f " "stitch_ms=%.2f h2d_submit_ms=%.2f sync_ms=%.2f " @@ -492,11 +565,14 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: getattr(self, "_rank", "?"), req.req_id, hbm, + hbm_floor, ls.lmcache_cached_tokens, len(chunks), nblocks, nbytes / 1024**3, copy_calls, + partial_first_chunk, + skipped_shared_blocks, self._codec.layout, process_ms, contains_ms, @@ -533,7 +609,7 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: contains_ms += (time.perf_counter() - t0) * 1000 try: - for (s, e, key) in chunks: + for (s, e, key, bids, src_skip_blocks, chunk_block_count) in copy_spans: t0 = time.perf_counter() mo = self._sm.get(key) get_ms += (time.perf_counter() - t0) * 1000 @@ -556,27 +632,52 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: ) return loaded_objs.append(mo) - bids = chunk_bids[len(loaded_objs) - 1] if self._codec.layout != "segment_indexed": copy_calls += self._codec.copy_calls_for_block_ids(bids) t0 = time.perf_counter() - self._codec.host_to_gpu(mo.tensor, bids, stream) + self._codec.host_to_gpu_block_range( + mo.tensor, + src_skip_blocks, + bids, + stream, + src_block_count=chunk_block_count, + ) h2d_submit_ms += (time.perf_counter() - t0) * 1000 if self._codec.layout == "segment_indexed": - copy_calls = self._codec.copy_calls_for_block_ids(all_bids) - t0 = time.perf_counter() - req_buf = self._host_tmp(nbytes) - host_alloc_ms += (time.perf_counter() - t0) * 1000 - t0 = time.perf_counter() - self._codec.stitch_chunk_buffers( - req_buf, - [mo.tensor for mo in loaded_objs], - [len(bids) for bids in chunk_bids], - ) - stitch_ms += (time.perf_counter() - t0) * 1000 - t0 = time.perf_counter() - self._codec.host_to_gpu(req_buf, all_bids, stream) - h2d_submit_ms += (time.perf_counter() - t0) * 1000 + if partial_first_chunk: + for ( + _s, + _e, + _key, + bids, + src_skip_blocks, + chunk_block_count, + ), mo in zip(copy_spans, loaded_objs): + copy_calls += self._codec.copy_calls_for_block_ids(bids) + t0 = time.perf_counter() + self._codec.host_to_gpu_block_range( + mo.tensor, + src_skip_blocks, + bids, + stream, + src_block_count=chunk_block_count, + ) + h2d_submit_ms += (time.perf_counter() - t0) * 1000 + else: + copy_calls = self._codec.copy_calls_for_block_ids(all_bids) + t0 = time.perf_counter() + req_buf = self._host_tmp(nbytes) + host_alloc_ms += (time.perf_counter() - t0) * 1000 + t0 = time.perf_counter() + self._codec.stitch_chunk_buffers( + req_buf, + [mo.tensor for mo in loaded_objs], + [span[5] for span in copy_spans], + ) + stitch_ms += (time.perf_counter() - t0) * 1000 + t0 = time.perf_counter() + self._codec.host_to_gpu(req_buf, all_bids, stream) + h2d_submit_ms += (time.perf_counter() - t0) * 1000 t0 = time.perf_counter() stream.synchronize() sync_ms += (time.perf_counter() - t0) * 1000 @@ -606,6 +707,8 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: chunks=len(chunks), blocks=nblocks, bytes_gib=f"{nbytes / 1024**3:.3f}", + partial_first_chunk=partial_first_chunk, + skipped_shared_blocks=skipped_shared_blocks, stitch_ms=f"{stitch_ms:.2f}", h2d_submit_ms=f"{h2d_submit_ms:.2f}", sync_ms=f"{sync_ms:.2f}", @@ -613,20 +716,25 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: ) if self._profile_enabled(): logger.info( - "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " - "chunks=%d blocks=%d bytes=%.3fGiB copy_calls=%d " - "layout=%s fastpath=chunk process_ms=%.2f contains_ms=%.2f " + "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d hbm_floor=%d " + "lmc=%d chunks=%d blocks=%d bytes=%.3fGiB copy_calls=%d " + "partial_first_chunk=%d skipped_shared_blocks=%d " + "layout=%s fastpath=%s process_ms=%.2f contains_ms=%.2f " "get_ms=%.2f host_alloc_ms=%.2f stitch_ms=%.2f " "h2d_submit_ms=%.2f sync_ms=%.2f total_ms=%.2f", getattr(self, "_rank", "?"), req.req_id, hbm, + hbm_floor, ls.lmcache_cached_tokens, len(chunks), nblocks, nbytes / 1024**3, copy_calls, + partial_first_chunk, + skipped_shared_blocks, self._codec.layout, + "chunk_partial" if partial_first_chunk else "chunk", process_ms, contains_ms, get_ms, @@ -1024,6 +1132,96 @@ def update_state_after_alloc(self, seq) -> None: if sid not in self._save_tracker: self._save_tracker[sid] = [seq, 0] + def _clear_pending_load(self, sid: str) -> None: + self._load_specs.pop(sid, None) + self._reqs_need_recv.pop(sid, None) + self._lookup_in_step = [x for x in self._lookup_in_step if x != sid] + if self._lookup_client is not None: + try: + self._lookup_client.clear_lookup_status(sid) + except Exception: + pass + + def should_park_for_load_after_alloc(self, seq) -> bool: + """Return True only when a real worker-side load is still needed. + + Lookup runs before ATOM's block allocation/prefix-cache match, so the + LoadSpec can become stale: allocation may discover an HBM hit that + already satisfies the LMCache hit. A chunk-unaligned HBM floor is still + loadable when it is block-aligned: the worker retrieves the overlapping + LMCache chunk and copies only the missing private tail blocks. + """ + sid = str(seq.id) + ls = self._load_specs.get(sid) + if ls is None: + return False + + ls.hbm_cached_tokens = int(seq.num_cached_tokens) + if ls.hbm_cached_tokens >= int(ls.lmcache_cached_tokens): + seq.offload_loaded_tokens = int(seq.num_cached_tokens) + logger.info( + "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " + "reason=hbm_satisfies_after_alloc", + seq.id, + ls.hbm_cached_tokens, + ls.lmcache_cached_tokens, + ) + offload_trace( + "scheduler_load_hbm_satisfies_after_alloc", + req=seq.id, + hbm=ls.hbm_cached_tokens, + lmc=ls.lmcache_cached_tokens, + blocks=len(list(seq.block_table)), + ) + self._clear_pending_load(sid) + return False + + chunk = self.chunk_size or 256 + if ls.hbm_cached_tokens % chunk != 0: + block = int(getattr(self, "block_size", 1) or 1) + if ls.hbm_cached_tokens % block != 0: + seq.offload_loaded_tokens = int(seq.num_cached_tokens) + logger.info( + "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " + "reason=unaligned_hbm chunk=%d block=%d", + seq.id, + ls.hbm_cached_tokens, + ls.lmcache_cached_tokens, + chunk, + block, + ) + offload_trace( + "scheduler_load_unaligned_hbm", + req=seq.id, + hbm=ls.hbm_cached_tokens, + lmc=ls.lmcache_cached_tokens, + chunk=chunk, + block=block, + blocks=len(list(seq.block_table)), + ) + self._clear_pending_load(sid) + return False + logger.info( + "[OFFLOAD-LOAD-PARTIAL] seq=%s hbm_cached=%d lmc_cached=%d " + "reason=partial_hbm_chunk chunk=%d block=%d", + seq.id, + ls.hbm_cached_tokens, + ls.lmcache_cached_tokens, + chunk, + block, + ) + offload_trace( + "scheduler_load_partial_hbm_chunk", + req=seq.id, + hbm=ls.hbm_cached_tokens, + lmc=ls.lmcache_cached_tokens, + chunk=chunk, + block=block, + blocks=len(list(seq.block_table)), + ) + + return True + def build_connector_meta(self) -> LMCacheOffloadMetadata: meta = LMCacheOffloadMetadata() meta.lookup_requests_in_step = self._lookup_in_step @@ -1041,9 +1239,9 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: # get_num_new_matched_tokens runs BEFORE the prefix-cache match in # block_manager.allocate, so seq.num_cached_tokens was stale (often # 0) when the LoadSpec was recorded. By now (post-allocate) it is the - # true HBM hit. Loading below this floor would overwrite HBM - # prefix-cache blocks (possibly shared with other seqs) -> output - # corruption. So load only [hbm_cached, offload_hit). + # true HBM hit. Loading below this floor would write through blocks + # owned by prefix cache and possibly shared with other seqs. So load + # only [hbm_cached, offload_hit). ls.hbm_cached_tokens = int(seq.num_cached_tokens) if ls.hbm_cached_tokens >= int(ls.lmcache_cached_tokens): seq.offload_loaded_tokens = int(seq.num_cached_tokens) @@ -1061,51 +1259,51 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: lmc=ls.lmcache_cached_tokens, blocks=len(list(seq.block_table)), ) - # The request may already be parked in WAITING_FOR_REMOTE_KVS. - # Emit a no-op load so every worker reports finished_recving via - # the normal aggregation path instead of trying to complete it - # locally in the scheduler process. - meta.add_request(LMCacheReqMeta( - req_id=seq.id, - token_ids=list(seq.token_ids[: ls.lmcache_cached_tokens]), - block_ids=list(seq.block_table), - load_spec=ls, - )) + self._clear_pending_load(sid) continue chunk = self.chunk_size or 256 if ls.hbm_cached_tokens % chunk != 0: - seq.offload_loaded_tokens = int(seq.num_cached_tokens) + block = int(getattr(self, "block_size", 1) or 1) + if ls.hbm_cached_tokens % block != 0: + seq.offload_loaded_tokens = int(seq.num_cached_tokens) + logger.info( + "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " + "reason=unaligned_hbm chunk=%d block=%d", + seq.id, + ls.hbm_cached_tokens, + ls.lmcache_cached_tokens, + chunk, + block, + ) + offload_trace( + "scheduler_load_unaligned_hbm", + req=seq.id, + hbm=ls.hbm_cached_tokens, + lmc=ls.lmcache_cached_tokens, + chunk=chunk, + block=block, + blocks=len(list(seq.block_table)), + ) + self._clear_pending_load(sid) + continue logger.info( - "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " - "reason=unaligned_hbm chunk=%d", + "[OFFLOAD-LOAD-PARTIAL] seq=%s hbm_cached=%d lmc_cached=%d " + "reason=partial_hbm_chunk chunk=%d block=%d", seq.id, ls.hbm_cached_tokens, ls.lmcache_cached_tokens, chunk, + block, ) offload_trace( - "scheduler_load_unaligned_hbm", + "scheduler_load_partial_hbm_chunk", req=seq.id, hbm=ls.hbm_cached_tokens, lmc=ls.lmcache_cached_tokens, chunk=chunk, + block=block, blocks=len(list(seq.block_table)), ) - # LMCache chunks can only be loaded from a chunk boundary. Do - # not round down and overwrite HBM prefix-cache blocks that may - # be shared with other requests; wake the parked request and let - # it continue prefill from the HBM floor. - meta.add_request(LMCacheReqMeta( - req_id=seq.id, - token_ids=list(seq.token_ids[: ls.hbm_cached_tokens]), - block_ids=list(seq.block_table), - load_spec=LoadSpec( - hbm_cached_tokens=ls.hbm_cached_tokens, - lmcache_cached_tokens=ls.hbm_cached_tokens, - can_load=True, - ), - )) - continue # num_cached after load = max(HBM, offload); never drop below HBM. seq.offload_loaded_tokens = max( int(seq.num_cached_tokens), int(ls.lmcache_cached_tokens) @@ -1113,15 +1311,33 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: # req_id MUST be the raw seq.id (the type the scheduler compares # against in _update_waiting_for_remote_kv); str(seq.id) is only for # LMCache's lookup/pin API. A str here silently never wakes the seq. - logger.info("[OFFLOAD-LOAD-EMIT] seq=%s hbm_cached=%d lmc_cached=%d offload_loaded=%d nblocks=%d", - seq.id, ls.hbm_cached_tokens, ls.lmcache_cached_tokens, - seq.offload_loaded_tokens, len(list(seq.block_table))) + block = int(getattr(self, "block_size", 1) or 1) + partial_first_chunk = int(ls.hbm_cached_tokens % chunk != 0) + skipped_shared_blocks = 0 + if partial_first_chunk: + skipped_shared_blocks = ( + ls.hbm_cached_tokens - (ls.hbm_cached_tokens // chunk) * chunk + ) // block + logger.info( + "[OFFLOAD-LOAD-EMIT] seq=%s hbm_cached=%d lmc_cached=%d " + "offload_loaded=%d nblocks=%d partial_first_chunk=%d " + "skipped_shared_blocks=%d", + seq.id, + ls.hbm_cached_tokens, + ls.lmcache_cached_tokens, + seq.offload_loaded_tokens, + len(list(seq.block_table)), + partial_first_chunk, + skipped_shared_blocks, + ) offload_trace( "scheduler_load_emit", req=seq.id, hbm=ls.hbm_cached_tokens, lmc=ls.lmcache_cached_tokens, offload_loaded=seq.offload_loaded_tokens, + partial_first_chunk=partial_first_chunk, + skipped_shared_blocks=skipped_shared_blocks, blocks=len(list(seq.block_table)), ) meta.add_request(LMCacheReqMeta( diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py index 9f4009bbee..767eaa7fc9 100644 --- a/atom/kv_transfer/offload/gpu_connector.py +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -433,6 +433,86 @@ def host_to_gpu( non_blocking=True, ) + def host_to_gpu_block_range( + self, + host_buf: torch.Tensor, + src_block_start: int, + block_ids: list[int], + stream: torch.cuda.Stream | None = None, + src_block_count: int | None = None, + ) -> None: + """H2D a logical block subrange from a host buffer. + + ``host_buf`` is a complete LMCache object containing ``src_block_count`` + logical blocks in this codec's current layout. Copy the subrange starting + at ``src_block_start`` into ``block_ids``. This is needed when the first + LMCache chunk overlaps an HBM prefix-cache hit: the host object contains + the full chunk, but only the private tail blocks should be restored. + """ + src_block_start = operator.index(src_block_start) + block_ids = self._normalize_block_ids(block_ids) + if src_block_count is None: + if int(host_buf.numel()) % self.bytes_per_block != 0: + raise ValueError( + "ATOMKVByteCodec: host_buf size is not block-aligned" + ) + src_block_count = int(host_buf.numel()) // self.bytes_per_block + src_block_count = operator.index(src_block_count) + if src_block_start < 0: + raise ValueError("ATOMKVByteCodec: src_block_start must be non-negative") + if src_block_count < 0: + raise ValueError("ATOMKVByteCodec: src_block_count must be non-negative") + if src_block_start + len(block_ids) > src_block_count: + raise ValueError( + "ATOMKVByteCodec: source block range is out of bounds " + f"start={src_block_start} blocks={len(block_ids)} " + f"source_blocks={src_block_count}" + ) + self._validate_host_buf(host_buf, src_block_count) + if not block_ids: + return + + if self.layout == "block": + offset = src_block_start * self.bytes_per_block + self.host_to_gpu(host_buf[offset:], block_ids, stream) + return + + with self._device_ctx(): + stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() + with stream_ctx: + if self.layout == "segment_indexed": + idx = torch.tensor( + block_ids, dtype=torch.long, device=self._device + ) + bases = self._segment_bases(src_block_count) + for seg, base, nb in zip( + self._segments, bases, self._seg_block_bytes + ): + mat = self._segment_bytes_matrix(seg) + tmp = self._tmp_bytes(seg, len(block_ids)) + start = base + src_block_start * nb + tmp.copy_( + host_buf[start : start + len(block_ids) * nb].reshape_as( + tmp + ), + non_blocking=True, + ) + mat.index_copy_(0, idx, tmp) + return + + bases = self._segment_bases(src_block_count) + runs = list(self._contiguous_runs(block_ids)) + for seg, base, nb in zip( + self._segments, bases, self._seg_block_bytes + ): + for logical_start, physical_start, run_len in runs: + dst = self._blocks_bytes_view(seg, physical_start, run_len) + src = base + (src_block_start + logical_start) * nb + dst.copy_( + host_buf[src : src + run_len * nb], + non_blocking=True, + ) + class _nullctx: def __enter__(self): diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 991a108dd6..4d30b76429 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -650,6 +650,8 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: num_scheduled_tokens: list[int] = [] scheduled_spec_decode_tokens: dict[int, np.ndarray] = {} + self._promote_ready_remote_kv_requests() + # ─── Cross-DP prefill alignment (PrefillDelayer) ─────────────── _delayer_allows_prefill = True if self.prefill_delayer is not None: @@ -855,6 +857,12 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: if self.kv_connector is not None: self.kv_connector.update_state_after_alloc(seq) + if need_to_remove_to_load_kv_async_queue: + if hasattr(self.kv_connector, "should_park_for_load_after_alloc"): + need_to_remove_to_load_kv_async_queue = ( + self.kv_connector.should_park_for_load_after_alloc(seq) + ) + if need_to_remove_to_load_kv_async_queue: offload_trace( "scheduler_park_for_load", @@ -1337,6 +1345,18 @@ def _is_offload_connector(self) -> bool: """ return getattr(self.kv_connector, "is_offload", False) + @staticmethod + def _has_req_id(req_ids: list, seq_id) -> bool: + candidates = (seq_id, str(seq_id)) + for candidate in candidates: + if candidate in req_ids: + return True + try: + int_id = int(seq_id) + except (TypeError, ValueError): + return False + return int_id in req_ids + @staticmethod def _pop_req_id(req_ids: list, seq_id) -> bool: candidates = (seq_id, str(seq_id)) @@ -1373,6 +1393,44 @@ def _update_waiting_for_remote_kv(self, seq: Sequence) -> bool: ) return True + def _promote_ready_remote_kv_requests(self) -> None: + """Move completed remote-KV waiters ahead of fresh admissions. + + Offload/remote-KV waiters already own their allocated block table. If a + later fresh request reaches the head of ``waiting`` while HBM is full, + ``can_allocate()`` breaks the admission loop before the scheduler can + inspect and wake completed remote waiters behind it. The completed + waiters then cannot finish and free blocks, so the fresh request also + cannot allocate. Keep normal FIFO order otherwise. + """ + if not self.waiting or not ( + self.finished_recving_kv_req_ids or self.failed_recving_kv_req_ids + ): + return + + ready: deque[Sequence] = deque() + blocked: deque[Sequence] = deque() + while self.waiting: + seq = self.waiting.popleft() + if seq.status == SequenceStatus.WAITING_FOR_REMOTE_KVS and ( + self._has_req_id(self.finished_recving_kv_req_ids, seq.id) + or self._has_req_id(self.failed_recving_kv_req_ids, seq.id) + ): + ready.append(seq) + else: + blocked.append(seq) + + if ready: + offload_trace( + "scheduler_promote_remote_ready", + reqs=[seq.id for seq in ready], + rest=len(blocked), + ) + self.waiting.extend(ready) + self.waiting.extend(blocked) + else: + self.waiting.extend(blocked) + def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): """Reconcile scheduler state with completed KV transfers. diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index 85b0507dce..c2d4d000f9 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -158,6 +158,70 @@ def test_segment_indexed_stitches_chunk_buffers(monkeypatch): assert torch.equal(actual, expected) +@pytest.mark.parametrize("layout", ["block", "segment", "segment_indexed"]) +def test_codec_h2d_block_subrange(monkeypatch, layout): + import torch + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", layout) + original = { + "l0": SimpleNamespace( + k_cache=torch.arange(8 * 2 * 3, dtype=torch.uint8).reshape(8, 2, 3), + v_cache=(torch.arange(8 * 4, dtype=torch.uint8).reshape(8, 4) + 51), + k_scale=torch.arange(8, dtype=torch.uint8).reshape(8, 1) + 101, + v_scale=torch.arange(8, dtype=torch.uint8).reshape(8, 1) + 151, + ), + "l1": SimpleNamespace( + k_cache=(torch.arange(8 * 3, dtype=torch.uint8).reshape(8, 3) + 201), + v_cache=(torch.arange(8 * 2, dtype=torch.uint8).reshape(8, 2) + 31), + k_scale=None, + v_scale=None, + ), + } + kv_caches = { + name: SimpleNamespace( + k_cache=layer.k_cache.clone(), + v_cache=layer.v_cache.clone(), + k_scale=layer.k_scale.clone() if layer.k_scale is not None else None, + v_scale=layer.v_scale.clone() if layer.v_scale is not None else None, + ) + for name, layer in original.items() + } + codec = ATOMKVByteCodec(kv_caches) + source_ids = [0, 1, 2, 3] + host = torch.empty(len(source_ids) * codec.bytes_per_block, dtype=torch.uint8) + codec.gpu_to_host(host, source_ids) + + for layer in kv_caches.values(): + layer.k_cache.zero_() + layer.v_cache.zero_() + if layer.k_scale is not None: + layer.k_scale.zero_() + if layer.v_scale is not None: + layer.v_scale.zero_() + + codec.host_to_gpu_block_range( + host, + src_block_start=1, + block_ids=[5, 7], + src_block_count=len(source_ids), + ) + + for name, layer in kv_caches.items(): + src = original[name] + assert torch.equal(layer.k_cache[5], src.k_cache[1]) + assert torch.equal(layer.v_cache[5], src.v_cache[1]) + assert torch.equal(layer.k_cache[7], src.k_cache[2]) + assert torch.equal(layer.v_cache[7], src.v_cache[2]) + assert torch.count_nonzero(layer.k_cache[0]).item() == 0 + if layer.k_scale is not None: + assert torch.equal(layer.k_scale[5], src.k_scale[1]) + assert torch.equal(layer.v_scale[5], src.v_scale[1]) + assert torch.equal(layer.k_scale[7], src.k_scale[2]) + assert torch.equal(layer.v_scale[7], src.v_scale[2]) + + @pytest.mark.parametrize("layout", ["block", "segment", "segment_indexed"]) @pytest.mark.parametrize("method_name", ["gpu_to_host", "host_to_gpu"]) def test_codec_rejects_invalid_block_ids_before_copy(monkeypatch, layout, method_name): @@ -285,24 +349,19 @@ def test_load_is_skipped_if_hbm_satisfies_after_allocation(): assert should_park is True # Prefix-cache allocation can discover a larger HBM hit than the lookup-time - # snapshot. In that case the scheduler still emits a no-op load so the - # normal worker aggregation path can wake the parked seq. + # snapshot. In that case the scheduler should not park the request or emit a + # worker no-op; it can continue prefill locally from the HBM floor. seq.num_cached_tokens = 8 sched.update_state_after_alloc(seq) + assert sched.should_park_for_load_after_alloc(seq) is False meta = sched.build_connector_meta() - assert len(meta.requests) == 1 - req = meta.requests[0] - assert req.req_id == 321 - assert req.token_ids == list(range(8)) - assert req.block_ids == [1, 2, 3] - assert req.load_spec.hbm_cached_tokens == 8 - assert req.load_spec.lmcache_cached_tokens == 8 + assert [r for r in meta.requests if r.load_spec is not None] == [] assert seq.offload_loaded_tokens == 8 - assert lookup.cleared == [] + assert lookup.cleared == ["321"] -def test_load_is_skipped_if_hbm_floor_is_not_chunk_aligned(): +def test_load_is_skipped_if_hbm_floor_is_not_block_aligned(): sched = _scheduler() lookup = _LookupClient(hit=12) sched._lookup_client = lookup @@ -318,22 +377,81 @@ def test_load_is_skipped_if_hbm_floor_is_not_chunk_aligned(): assert need == 12 assert should_park is True - # HBM prefix cache can return block-size granularity, while LMCache chunks - # are larger. Loading from a non-chunk boundary would either overlap shared - # HBM blocks or leave a gap, so the scheduler wakes the seq with a no-op - # load and lets suffix prefill continue from the HBM floor. + # A floor inside an ATOM block cannot be restored with whole-block H2D + # copies, so the scheduler keeps the seq local and lets suffix prefill + # continue from the HBM floor. seq.num_cached_tokens = 6 sched.update_state_after_alloc(seq) + assert sched.should_park_for_load_after_alloc(seq) is False + meta = sched.build_connector_meta() + + assert [r for r in meta.requests if r.load_spec is not None] == [] + assert seq.offload_loaded_tokens == 6 + + +def test_block_aligned_partial_hbm_chunk_still_parks_after_allocation(): + sched = _scheduler() + sched.chunk_size = 8 + sched.block_size = 4 + lookup = _LookupClient(hit=16) + sched._lookup_client = lookup + seq = SimpleNamespace( + id=655, + num_prompt_tokens=24, + token_ids=list(range(24)), + num_cached_tokens=0, + block_table=[1, 2, 3, 4, 5, 6], + ) + + need, should_park = sched.get_num_new_matched_tokens(seq) + assert need == 16 + assert should_park is True + + seq.num_cached_tokens = 12 + sched.update_state_after_alloc(seq) + assert sched.should_park_for_load_after_alloc(seq) is True meta = sched.build_connector_meta() assert len(meta.requests) == 1 req = meta.requests[0] - assert req.req_id == 654 - assert req.token_ids == list(range(6)) + assert req.req_id == 655 + assert req.token_ids == list(range(16)) + assert req.block_ids == [1, 2, 3, 4, 5, 6] + assert req.load_spec.hbm_cached_tokens == 12 + assert req.load_spec.lmcache_cached_tokens == 16 + assert seq.offload_loaded_tokens == 16 + assert lookup.cleared == [] + + +def test_real_aligned_load_still_parks_after_allocation(): + sched = _scheduler() + lookup = _LookupClient(hit=12) + sched._lookup_client = lookup + seq = SimpleNamespace( + id=777, + num_prompt_tokens=16, + token_ids=list(range(16)), + num_cached_tokens=0, + block_table=[1, 2, 3, 4], + ) + + need, should_park = sched.get_num_new_matched_tokens(seq) + assert need == 12 + assert should_park is True + + seq.num_cached_tokens = 4 + sched.update_state_after_alloc(seq) + assert sched.should_park_for_load_after_alloc(seq) is True + meta = sched.build_connector_meta() + + assert len(meta.requests) == 1 + req = meta.requests[0] + assert req.req_id == 777 + assert req.token_ids == list(range(12)) assert req.block_ids == [1, 2, 3, 4] - assert req.load_spec.hbm_cached_tokens == 6 - assert req.load_spec.lmcache_cached_tokens == 6 - assert seq.offload_loaded_tokens == 6 + assert req.load_spec.hbm_cached_tokens == 4 + assert req.load_spec.lmcache_cached_tokens == 12 + assert lookup.cleared == [] def test_worker_completes_noop_load_when_hbm_satisfies(): @@ -359,13 +477,14 @@ def test_worker_completes_noop_load_when_hbm_satisfies(): assert conn._engine.unpinned == ["321"] -def test_worker_reports_unaligned_hbm_load_as_failed_without_exception(): +def test_worker_reports_non_block_aligned_hbm_load_as_failed_without_exception(): conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) conn._lock = threading.Lock() conn._done_load = set() conn._failed_load = set() conn._done_save = set() conn.chunk_size = 4 + conn.block_size = 4 conn._engine = SimpleNamespace(unpinned=[]) conn._engine.lookup_unpin = lambda ids: conn._engine.unpinned.extend(ids) @@ -383,6 +502,105 @@ def test_worker_reports_unaligned_hbm_load_as_failed_without_exception(): assert conn._engine.unpinned == ["654"] +def test_worker_partial_hbm_chunk_copies_only_missing_tail_blocks(): + import torch + if not hasattr(torch, "ones"): + pytest.skip("real torch is unavailable") + + conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) + conn._lock = threading.Lock() + conn._done_load = set() + conn._failed_load = set() + conn._done_save = set() + conn.chunk_size = 8 + conn.block_size = 4 + conn._request_fastpath = False + conn._rank = 0 + conn._engine = SimpleNamespace(unpinned=[]) + conn._engine.lookup_unpin = lambda ids: conn._engine.unpinned.extend(ids) + + class _Stream: + def __init__(self): + self.syncs = 0 + + def synchronize(self): + self.syncs += 1 + + stream = _Stream() + conn._stream = lambda: stream + + masks = [] + + class _TDB: + def process_tokens(self, token_ids, mask=None): + masks.append(mask.clone()) + return [(8, 16, "k1"), (16, 24, "k2")] + + class _MemoryObj: + def __init__(self, key): + self.key = key + self.tensor = torch.empty(2, dtype=torch.uint8) + self.refs = 1 + + def ref_count_down(self): + self.refs -= 1 + + class _SM: + def __init__(self): + self.objs = {"k1": _MemoryObj("k1"), "k2": _MemoryObj("k2")} + + def contains(self, key): + return key in self.objs + + def get(self, key): + return self.objs[key] + + copy_calls = [] + + class _Codec: + layout = "segment_indexed" + bytes_per_block = 1 + + def copy_calls_for_block_ids(self, block_ids): + return len(block_ids) + + def host_to_gpu_block_range( + self, + tensor, + src_block_start, + block_ids, + stream_arg=None, + src_block_count=None, + ): + copy_calls.append( + (src_block_start, list(block_ids), stream_arg, src_block_count) + ) + + conn._tdb = _TDB() + conn._sm = _SM() + conn._codec = _Codec() + + req = SimpleNamespace( + req_id=655, + token_ids=list(range(24)), + block_ids=[100, 101, 102, 103, 104, 105], + load_spec=SimpleNamespace(hbm_cached_tokens=12, lmcache_cached_tokens=24), + ) + + conn._do_load_req(req) + + assert masks + assert masks[0][:8].any().item() is False + assert masks[0][8:].all().item() is True + assert copy_calls == [ + (1, [103], stream, 2), + (0, [104, 105], stream, 2), + ] + assert conn._done_load == {655} + assert conn._failed_load == set() + assert conn._engine.unpinned == ["655"] + + def test_load_exception_is_reported_as_failed_recving(): conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) conn._lock = threading.Lock() diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index e53d84b935..3751fe6e0e 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -141,6 +141,53 @@ def test_prefill_respects_block_availability(self, seq_factory): batch, _ = sched.schedule() assert batch.total_seqs_num_prefill == 1 + def test_ready_remote_kv_waiter_bypasses_blocked_fresh_request(self, seq_factory): + class _OffloadConnector: + is_offload = True + is_producer = False + + def get_num_new_matched_tokens(self, seq): + return 0, False + + def build_connector_meta(self): + return None + + sched = Scheduler( + MockConfig( + num_kvcache_blocks=4, + kv_cache_block_size=4, + max_model_len=64, + max_num_batched_tokens=64, + ) + ) + sched.kv_connector = _OffloadConnector() + + blocked = seq_factory(list(range(12)), block_size=4) + remote = seq_factory(list(range(100, 116)), block_size=4) + remote.status = SequenceStatus.WAITING_FOR_REMOTE_KVS + remote.block_table = [0, 1, 2, 3] + remote.num_cached_tokens = 4 + remote.offload_loaded_tokens = 12 + + for block_id in remote.block_table: + block = sched.block_manager.blocks[block_id] + block.ref_count = 1 + sched.block_manager.used_block_ids = set(remote.block_table) + sched.block_manager.free_block_ids.clear() + sched.block_manager.free_block_ids_set.clear() + + sched.add(blocked) + sched.add(remote) + sched.finished_recving_kv_req_ids = [remote.id] + + batch, seqs = sched.schedule() + + assert remote.id in seqs + assert blocked.status == SequenceStatus.WAITING + assert remote.status == SequenceStatus.RUNNING + assert remote.num_cached_tokens == 12 + assert list(batch.scheduled_tokens) == remote.token_ids[12:16] + def test_decode_after_prefill(self, scheduler, seq_factory): seq = seq_factory([1, 2, 3, 4]) scheduler.add(seq) From 5f76dd7c492f224a3a3bbc9f4b27d88758cdaf5b Mon Sep 17 00:00:00 2001 From: yihonglie Date: Mon, 1 Jun 2026 04:33:12 -0500 Subject: [PATCH 05/27] Revert "WIP lmcache partial reload and HOL wake" This reverts commit 77dbd02d85fef54063fa86e8cfecf00a39a2a703. --- atom/kv_transfer/offload/connector.py | 426 ++++++---------------- atom/kv_transfer/offload/gpu_connector.py | 80 ---- atom/model_engine/scheduler.py | 58 --- tests/test_lmcache_offload_connector.py | 260 ++----------- tests/test_scheduler.py | 47 --- 5 files changed, 126 insertions(+), 745 deletions(-) diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index 5f97a073c0..658c320fef 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -345,20 +345,6 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: toks=len(toks), blocks=len(req.block_ids), ) - - def fail_load(status: str, **fields) -> None: - self._lookup_unpin(req.req_id) - with self._lock: - self._failed_load.add(req.req_id) - offload_trace( - "worker_load_done", - rank=getattr(self, "_rank", "?"), - req=req.req_id, - status=status, - total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", - **fields, - ) - if int(ls.lmcache_cached_tokens) <= hbm: self._lookup_unpin(req.req_id) with self._lock: @@ -372,131 +358,74 @@ def fail_load(status: str, **fields) -> None: ) return chunk_size = int(self.chunk_size or 256) - block_size = int(getattr(self, "block_size", 1) or 1) - if hbm % block_size != 0: + if hbm % chunk_size != 0: logger.warning( - "LMCache offload: HBM prefix is not block-aligned req=%s " - "hbm=%d block=%d; re-prefill", + "LMCache offload: HBM prefix is not chunk-aligned req=%s " + "hbm=%d chunk=%d; re-prefill", req.req_id, hbm, - block_size, - ) - fail_load( - "unaligned_hbm", - hbm=hbm, - chunk=chunk_size, - block=block_size, - ) - return - if chunk_size % block_size != 0: - logger.warning( - "LMCache offload: chunk size is not block-aligned req=%s " - "chunk=%d block=%d; re-prefill", - req.req_id, chunk_size, - block_size, ) - fail_load( - "unaligned_chunk_block", + self._lookup_unpin(req.req_id) + with self._lock: + self._failed_load.add(req.req_id) + offload_trace( + "worker_load_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="unaligned_hbm", + hbm=hbm, chunk=chunk_size, - block=block_size, + total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", ) return - hbm_floor = (hbm // chunk_size) * chunk_size stream = self._stream() mask = torch.ones(len(toks), dtype=torch.bool) - mask[:hbm_floor] = False + mask[:hbm] = False t0 = time.perf_counter() - chunks = [ - (int(s), int(e), key) - for (s, e, key) in self._tdb.process_tokens(torch.tensor(toks), mask=mask) - if int(e) > hbm - ] + chunks = list(self._tdb.process_tokens(torch.tensor(toks), mask=mask)) process_ms = (time.perf_counter() - t0) * 1000 - logger.debug( - "offload _do_load req=%s hbm=%d floor=%d lmc=%d chunks=%d", - req.req_id, - hbm, - hbm_floor, - ls.lmcache_cached_tokens, - len(chunks), - ) + logger.debug("offload _do_load req=%s hbm=%d lmc=%d chunks=%d", + req.req_id, hbm, ls.lmcache_cached_tokens, len(chunks)) - # All-or-nothing above the HBM prefix. When HBM lands inside an LMCache - # chunk, retrieve the overlapping chunk but copy only the block-aligned - # tail; shared HBM-hit blocks are owned by prefix cache and must not be - # written by this request's reload. + # All-or-nothing above the HBM prefix: a partial load would let attention + # read uninitialized blocks, and a chunk that overlaps an HBM-cache hit + # could overwrite shared prefix-cache blocks. In either case the seq + # wakes and re-prefills from its HBM floor. if not chunks: logger.warning("LMCache offload: no loadable chunks req=%s; re-prefill", req.req_id) - fail_load("no_chunks") + self._lookup_unpin(req.req_id) + with self._lock: + self._failed_load.add(req.req_id) + offload_trace( + "worker_load_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="no_chunks", + total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", + ) return - copy_spans = [] - skipped_shared_blocks = 0 - partial_first_chunk = 0 - for (s, e, key) in chunks: - copy_start = max(s, hbm) - copy_end = e - if copy_end <= copy_start: - continue - if ( - s % block_size != 0 - or e % block_size != 0 - or copy_start % block_size != 0 - or copy_end % block_size != 0 - ): + for (s, _e, _key) in chunks: + if s < hbm: logger.warning( - "LMCache offload: load span is not block-aligned req=%s " - "chunk=[%d,%d) copy=[%d,%d) block=%d; re-prefill", - req.req_id, - s, - e, - copy_start, - copy_end, - block_size, + "LMCache offload: chunk overlaps HBM prefix req=%s hbm=%d " + "chunk_start=%d; re-prefill", + req.req_id, hbm, s, ) - fail_load( - "unaligned_copy_span", + self._lookup_unpin(req.req_id) + with self._lock: + self._failed_load.add(req.req_id) + offload_trace( + "worker_load_done", + rank=getattr(self, "_rank", "?"), + req=req.req_id, + status="overlap_hbm", hbm=hbm, chunk_start=s, - chunk_end=e, - copy_start=copy_start, - copy_end=copy_end, - block=block_size, - ) - return - chunk_block_count = (e - s) // block_size - src_skip_blocks = (copy_start - s) // block_size - bids = self._block_ids(req, copy_start, copy_end) - expected_blocks = (copy_end - copy_start) // block_size - if len(bids) != expected_blocks: - logger.warning( - "LMCache offload: block table too short req=%s " - "copy=[%d,%d) expected_blocks=%d got=%d; re-prefill", - req.req_id, - copy_start, - copy_end, - expected_blocks, - len(bids), - ) - fail_load( - "bad_block_table", - hbm=hbm, - copy_start=copy_start, - copy_end=copy_end, - expected_blocks=expected_blocks, - got_blocks=len(bids), + total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", ) return - if src_skip_blocks: - partial_first_chunk = 1 - skipped_shared_blocks += src_skip_blocks - copy_spans.append((s, e, key, bids, src_skip_blocks, chunk_block_count)) - if not copy_spans: - logger.warning("LMCache offload: no copy spans req=%s; re-prefill", - req.req_id) - fail_load("no_copy_spans") - return contains_ms = 0.0 loaded_objs = [] get_ms = 0.0 @@ -507,7 +436,9 @@ def fail_load(status: str, **fields) -> None: nblocks = 0 nbytes = 0 copy_calls = 0 - chunk_bids: list[list[int]] = [span[3] for span in copy_spans] + chunk_bids: list[list[int]] = [ + self._block_ids(req, s, e) for (s, e, _key) in chunks + ] all_bids = [bid for bids in chunk_bids for bid in bids] nblocks = len(all_bids) nbytes = nblocks * self._codec.bytes_per_block @@ -545,8 +476,6 @@ def fail_load(status: str, **fields) -> None: chunks=len(chunks), blocks=nblocks, bytes_gib=f"{nbytes / 1024**3:.3f}", - partial_first_chunk=partial_first_chunk, - skipped_shared_blocks=skipped_shared_blocks, stitch_ms=f"{stitch_ms:.2f}", h2d_submit_ms=f"{h2d_submit_ms:.2f}", sync_ms=f"{sync_ms:.2f}", @@ -554,10 +483,8 @@ def fail_load(status: str, **fields) -> None: ) if self._profile_enabled(): logger.info( - "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d " - "hbm_floor=%d lmc=%d chunks=%d blocks=%d " - "bytes=%.3fGiB copy_calls=%d partial_first_chunk=%d " - "skipped_shared_blocks=%d " + "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " + "chunks=%d blocks=%d bytes=%.3fGiB copy_calls=%d " "layout=%s fastpath=request process_ms=%.2f " "contains_ms=%.2f get_ms=%.2f host_alloc_ms=%.2f " "stitch_ms=%.2f h2d_submit_ms=%.2f sync_ms=%.2f " @@ -565,14 +492,11 @@ def fail_load(status: str, **fields) -> None: getattr(self, "_rank", "?"), req.req_id, hbm, - hbm_floor, ls.lmcache_cached_tokens, len(chunks), nblocks, nbytes / 1024**3, copy_calls, - partial_first_chunk, - skipped_shared_blocks, self._codec.layout, process_ms, contains_ms, @@ -609,7 +533,7 @@ def fail_load(status: str, **fields) -> None: contains_ms += (time.perf_counter() - t0) * 1000 try: - for (s, e, key, bids, src_skip_blocks, chunk_block_count) in copy_spans: + for (s, e, key) in chunks: t0 = time.perf_counter() mo = self._sm.get(key) get_ms += (time.perf_counter() - t0) * 1000 @@ -632,52 +556,27 @@ def fail_load(status: str, **fields) -> None: ) return loaded_objs.append(mo) + bids = chunk_bids[len(loaded_objs) - 1] if self._codec.layout != "segment_indexed": copy_calls += self._codec.copy_calls_for_block_ids(bids) t0 = time.perf_counter() - self._codec.host_to_gpu_block_range( - mo.tensor, - src_skip_blocks, - bids, - stream, - src_block_count=chunk_block_count, - ) + self._codec.host_to_gpu(mo.tensor, bids, stream) h2d_submit_ms += (time.perf_counter() - t0) * 1000 if self._codec.layout == "segment_indexed": - if partial_first_chunk: - for ( - _s, - _e, - _key, - bids, - src_skip_blocks, - chunk_block_count, - ), mo in zip(copy_spans, loaded_objs): - copy_calls += self._codec.copy_calls_for_block_ids(bids) - t0 = time.perf_counter() - self._codec.host_to_gpu_block_range( - mo.tensor, - src_skip_blocks, - bids, - stream, - src_block_count=chunk_block_count, - ) - h2d_submit_ms += (time.perf_counter() - t0) * 1000 - else: - copy_calls = self._codec.copy_calls_for_block_ids(all_bids) - t0 = time.perf_counter() - req_buf = self._host_tmp(nbytes) - host_alloc_ms += (time.perf_counter() - t0) * 1000 - t0 = time.perf_counter() - self._codec.stitch_chunk_buffers( - req_buf, - [mo.tensor for mo in loaded_objs], - [span[5] for span in copy_spans], - ) - stitch_ms += (time.perf_counter() - t0) * 1000 - t0 = time.perf_counter() - self._codec.host_to_gpu(req_buf, all_bids, stream) - h2d_submit_ms += (time.perf_counter() - t0) * 1000 + copy_calls = self._codec.copy_calls_for_block_ids(all_bids) + t0 = time.perf_counter() + req_buf = self._host_tmp(nbytes) + host_alloc_ms += (time.perf_counter() - t0) * 1000 + t0 = time.perf_counter() + self._codec.stitch_chunk_buffers( + req_buf, + [mo.tensor for mo in loaded_objs], + [len(bids) for bids in chunk_bids], + ) + stitch_ms += (time.perf_counter() - t0) * 1000 + t0 = time.perf_counter() + self._codec.host_to_gpu(req_buf, all_bids, stream) + h2d_submit_ms += (time.perf_counter() - t0) * 1000 t0 = time.perf_counter() stream.synchronize() sync_ms += (time.perf_counter() - t0) * 1000 @@ -707,8 +606,6 @@ def fail_load(status: str, **fields) -> None: chunks=len(chunks), blocks=nblocks, bytes_gib=f"{nbytes / 1024**3:.3f}", - partial_first_chunk=partial_first_chunk, - skipped_shared_blocks=skipped_shared_blocks, stitch_ms=f"{stitch_ms:.2f}", h2d_submit_ms=f"{h2d_submit_ms:.2f}", sync_ms=f"{sync_ms:.2f}", @@ -716,25 +613,20 @@ def fail_load(status: str, **fields) -> None: ) if self._profile_enabled(): logger.info( - "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d hbm_floor=%d " - "lmc=%d chunks=%d blocks=%d bytes=%.3fGiB copy_calls=%d " - "partial_first_chunk=%d skipped_shared_blocks=%d " - "layout=%s fastpath=%s process_ms=%.2f contains_ms=%.2f " + "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " + "chunks=%d blocks=%d bytes=%.3fGiB copy_calls=%d " + "layout=%s fastpath=chunk process_ms=%.2f contains_ms=%.2f " "get_ms=%.2f host_alloc_ms=%.2f stitch_ms=%.2f " "h2d_submit_ms=%.2f sync_ms=%.2f total_ms=%.2f", getattr(self, "_rank", "?"), req.req_id, hbm, - hbm_floor, ls.lmcache_cached_tokens, len(chunks), nblocks, nbytes / 1024**3, copy_calls, - partial_first_chunk, - skipped_shared_blocks, self._codec.layout, - "chunk_partial" if partial_first_chunk else "chunk", process_ms, contains_ms, get_ms, @@ -1132,96 +1024,6 @@ def update_state_after_alloc(self, seq) -> None: if sid not in self._save_tracker: self._save_tracker[sid] = [seq, 0] - def _clear_pending_load(self, sid: str) -> None: - self._load_specs.pop(sid, None) - self._reqs_need_recv.pop(sid, None) - self._lookup_in_step = [x for x in self._lookup_in_step if x != sid] - if self._lookup_client is not None: - try: - self._lookup_client.clear_lookup_status(sid) - except Exception: - pass - - def should_park_for_load_after_alloc(self, seq) -> bool: - """Return True only when a real worker-side load is still needed. - - Lookup runs before ATOM's block allocation/prefix-cache match, so the - LoadSpec can become stale: allocation may discover an HBM hit that - already satisfies the LMCache hit. A chunk-unaligned HBM floor is still - loadable when it is block-aligned: the worker retrieves the overlapping - LMCache chunk and copies only the missing private tail blocks. - """ - sid = str(seq.id) - ls = self._load_specs.get(sid) - if ls is None: - return False - - ls.hbm_cached_tokens = int(seq.num_cached_tokens) - if ls.hbm_cached_tokens >= int(ls.lmcache_cached_tokens): - seq.offload_loaded_tokens = int(seq.num_cached_tokens) - logger.info( - "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " - "reason=hbm_satisfies_after_alloc", - seq.id, - ls.hbm_cached_tokens, - ls.lmcache_cached_tokens, - ) - offload_trace( - "scheduler_load_hbm_satisfies_after_alloc", - req=seq.id, - hbm=ls.hbm_cached_tokens, - lmc=ls.lmcache_cached_tokens, - blocks=len(list(seq.block_table)), - ) - self._clear_pending_load(sid) - return False - - chunk = self.chunk_size or 256 - if ls.hbm_cached_tokens % chunk != 0: - block = int(getattr(self, "block_size", 1) or 1) - if ls.hbm_cached_tokens % block != 0: - seq.offload_loaded_tokens = int(seq.num_cached_tokens) - logger.info( - "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " - "reason=unaligned_hbm chunk=%d block=%d", - seq.id, - ls.hbm_cached_tokens, - ls.lmcache_cached_tokens, - chunk, - block, - ) - offload_trace( - "scheduler_load_unaligned_hbm", - req=seq.id, - hbm=ls.hbm_cached_tokens, - lmc=ls.lmcache_cached_tokens, - chunk=chunk, - block=block, - blocks=len(list(seq.block_table)), - ) - self._clear_pending_load(sid) - return False - logger.info( - "[OFFLOAD-LOAD-PARTIAL] seq=%s hbm_cached=%d lmc_cached=%d " - "reason=partial_hbm_chunk chunk=%d block=%d", - seq.id, - ls.hbm_cached_tokens, - ls.lmcache_cached_tokens, - chunk, - block, - ) - offload_trace( - "scheduler_load_partial_hbm_chunk", - req=seq.id, - hbm=ls.hbm_cached_tokens, - lmc=ls.lmcache_cached_tokens, - chunk=chunk, - block=block, - blocks=len(list(seq.block_table)), - ) - - return True - def build_connector_meta(self) -> LMCacheOffloadMetadata: meta = LMCacheOffloadMetadata() meta.lookup_requests_in_step = self._lookup_in_step @@ -1239,9 +1041,9 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: # get_num_new_matched_tokens runs BEFORE the prefix-cache match in # block_manager.allocate, so seq.num_cached_tokens was stale (often # 0) when the LoadSpec was recorded. By now (post-allocate) it is the - # true HBM hit. Loading below this floor would write through blocks - # owned by prefix cache and possibly shared with other seqs. So load - # only [hbm_cached, offload_hit). + # true HBM hit. Loading below this floor would overwrite HBM + # prefix-cache blocks (possibly shared with other seqs) -> output + # corruption. So load only [hbm_cached, offload_hit). ls.hbm_cached_tokens = int(seq.num_cached_tokens) if ls.hbm_cached_tokens >= int(ls.lmcache_cached_tokens): seq.offload_loaded_tokens = int(seq.num_cached_tokens) @@ -1259,51 +1061,51 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: lmc=ls.lmcache_cached_tokens, blocks=len(list(seq.block_table)), ) - self._clear_pending_load(sid) + # The request may already be parked in WAITING_FOR_REMOTE_KVS. + # Emit a no-op load so every worker reports finished_recving via + # the normal aggregation path instead of trying to complete it + # locally in the scheduler process. + meta.add_request(LMCacheReqMeta( + req_id=seq.id, + token_ids=list(seq.token_ids[: ls.lmcache_cached_tokens]), + block_ids=list(seq.block_table), + load_spec=ls, + )) continue chunk = self.chunk_size or 256 if ls.hbm_cached_tokens % chunk != 0: - block = int(getattr(self, "block_size", 1) or 1) - if ls.hbm_cached_tokens % block != 0: - seq.offload_loaded_tokens = int(seq.num_cached_tokens) - logger.info( - "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " - "reason=unaligned_hbm chunk=%d block=%d", - seq.id, - ls.hbm_cached_tokens, - ls.lmcache_cached_tokens, - chunk, - block, - ) - offload_trace( - "scheduler_load_unaligned_hbm", - req=seq.id, - hbm=ls.hbm_cached_tokens, - lmc=ls.lmcache_cached_tokens, - chunk=chunk, - block=block, - blocks=len(list(seq.block_table)), - ) - self._clear_pending_load(sid) - continue + seq.offload_loaded_tokens = int(seq.num_cached_tokens) logger.info( - "[OFFLOAD-LOAD-PARTIAL] seq=%s hbm_cached=%d lmc_cached=%d " - "reason=partial_hbm_chunk chunk=%d block=%d", + "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " + "reason=unaligned_hbm chunk=%d", seq.id, ls.hbm_cached_tokens, ls.lmcache_cached_tokens, chunk, - block, ) offload_trace( - "scheduler_load_partial_hbm_chunk", + "scheduler_load_unaligned_hbm", req=seq.id, hbm=ls.hbm_cached_tokens, lmc=ls.lmcache_cached_tokens, chunk=chunk, - block=block, blocks=len(list(seq.block_table)), ) + # LMCache chunks can only be loaded from a chunk boundary. Do + # not round down and overwrite HBM prefix-cache blocks that may + # be shared with other requests; wake the parked request and let + # it continue prefill from the HBM floor. + meta.add_request(LMCacheReqMeta( + req_id=seq.id, + token_ids=list(seq.token_ids[: ls.hbm_cached_tokens]), + block_ids=list(seq.block_table), + load_spec=LoadSpec( + hbm_cached_tokens=ls.hbm_cached_tokens, + lmcache_cached_tokens=ls.hbm_cached_tokens, + can_load=True, + ), + )) + continue # num_cached after load = max(HBM, offload); never drop below HBM. seq.offload_loaded_tokens = max( int(seq.num_cached_tokens), int(ls.lmcache_cached_tokens) @@ -1311,33 +1113,15 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: # req_id MUST be the raw seq.id (the type the scheduler compares # against in _update_waiting_for_remote_kv); str(seq.id) is only for # LMCache's lookup/pin API. A str here silently never wakes the seq. - block = int(getattr(self, "block_size", 1) or 1) - partial_first_chunk = int(ls.hbm_cached_tokens % chunk != 0) - skipped_shared_blocks = 0 - if partial_first_chunk: - skipped_shared_blocks = ( - ls.hbm_cached_tokens - (ls.hbm_cached_tokens // chunk) * chunk - ) // block - logger.info( - "[OFFLOAD-LOAD-EMIT] seq=%s hbm_cached=%d lmc_cached=%d " - "offload_loaded=%d nblocks=%d partial_first_chunk=%d " - "skipped_shared_blocks=%d", - seq.id, - ls.hbm_cached_tokens, - ls.lmcache_cached_tokens, - seq.offload_loaded_tokens, - len(list(seq.block_table)), - partial_first_chunk, - skipped_shared_blocks, - ) + logger.info("[OFFLOAD-LOAD-EMIT] seq=%s hbm_cached=%d lmc_cached=%d offload_loaded=%d nblocks=%d", + seq.id, ls.hbm_cached_tokens, ls.lmcache_cached_tokens, + seq.offload_loaded_tokens, len(list(seq.block_table))) offload_trace( "scheduler_load_emit", req=seq.id, hbm=ls.hbm_cached_tokens, lmc=ls.lmcache_cached_tokens, offload_loaded=seq.offload_loaded_tokens, - partial_first_chunk=partial_first_chunk, - skipped_shared_blocks=skipped_shared_blocks, blocks=len(list(seq.block_table)), ) meta.add_request(LMCacheReqMeta( diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py index 767eaa7fc9..9f4009bbee 100644 --- a/atom/kv_transfer/offload/gpu_connector.py +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -433,86 +433,6 @@ def host_to_gpu( non_blocking=True, ) - def host_to_gpu_block_range( - self, - host_buf: torch.Tensor, - src_block_start: int, - block_ids: list[int], - stream: torch.cuda.Stream | None = None, - src_block_count: int | None = None, - ) -> None: - """H2D a logical block subrange from a host buffer. - - ``host_buf`` is a complete LMCache object containing ``src_block_count`` - logical blocks in this codec's current layout. Copy the subrange starting - at ``src_block_start`` into ``block_ids``. This is needed when the first - LMCache chunk overlaps an HBM prefix-cache hit: the host object contains - the full chunk, but only the private tail blocks should be restored. - """ - src_block_start = operator.index(src_block_start) - block_ids = self._normalize_block_ids(block_ids) - if src_block_count is None: - if int(host_buf.numel()) % self.bytes_per_block != 0: - raise ValueError( - "ATOMKVByteCodec: host_buf size is not block-aligned" - ) - src_block_count = int(host_buf.numel()) // self.bytes_per_block - src_block_count = operator.index(src_block_count) - if src_block_start < 0: - raise ValueError("ATOMKVByteCodec: src_block_start must be non-negative") - if src_block_count < 0: - raise ValueError("ATOMKVByteCodec: src_block_count must be non-negative") - if src_block_start + len(block_ids) > src_block_count: - raise ValueError( - "ATOMKVByteCodec: source block range is out of bounds " - f"start={src_block_start} blocks={len(block_ids)} " - f"source_blocks={src_block_count}" - ) - self._validate_host_buf(host_buf, src_block_count) - if not block_ids: - return - - if self.layout == "block": - offset = src_block_start * self.bytes_per_block - self.host_to_gpu(host_buf[offset:], block_ids, stream) - return - - with self._device_ctx(): - stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() - with stream_ctx: - if self.layout == "segment_indexed": - idx = torch.tensor( - block_ids, dtype=torch.long, device=self._device - ) - bases = self._segment_bases(src_block_count) - for seg, base, nb in zip( - self._segments, bases, self._seg_block_bytes - ): - mat = self._segment_bytes_matrix(seg) - tmp = self._tmp_bytes(seg, len(block_ids)) - start = base + src_block_start * nb - tmp.copy_( - host_buf[start : start + len(block_ids) * nb].reshape_as( - tmp - ), - non_blocking=True, - ) - mat.index_copy_(0, idx, tmp) - return - - bases = self._segment_bases(src_block_count) - runs = list(self._contiguous_runs(block_ids)) - for seg, base, nb in zip( - self._segments, bases, self._seg_block_bytes - ): - for logical_start, physical_start, run_len in runs: - dst = self._blocks_bytes_view(seg, physical_start, run_len) - src = base + (src_block_start + logical_start) * nb - dst.copy_( - host_buf[src : src + run_len * nb], - non_blocking=True, - ) - class _nullctx: def __enter__(self): diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 4d30b76429..991a108dd6 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -650,8 +650,6 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: num_scheduled_tokens: list[int] = [] scheduled_spec_decode_tokens: dict[int, np.ndarray] = {} - self._promote_ready_remote_kv_requests() - # ─── Cross-DP prefill alignment (PrefillDelayer) ─────────────── _delayer_allows_prefill = True if self.prefill_delayer is not None: @@ -857,12 +855,6 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: if self.kv_connector is not None: self.kv_connector.update_state_after_alloc(seq) - if need_to_remove_to_load_kv_async_queue: - if hasattr(self.kv_connector, "should_park_for_load_after_alloc"): - need_to_remove_to_load_kv_async_queue = ( - self.kv_connector.should_park_for_load_after_alloc(seq) - ) - if need_to_remove_to_load_kv_async_queue: offload_trace( "scheduler_park_for_load", @@ -1345,18 +1337,6 @@ def _is_offload_connector(self) -> bool: """ return getattr(self.kv_connector, "is_offload", False) - @staticmethod - def _has_req_id(req_ids: list, seq_id) -> bool: - candidates = (seq_id, str(seq_id)) - for candidate in candidates: - if candidate in req_ids: - return True - try: - int_id = int(seq_id) - except (TypeError, ValueError): - return False - return int_id in req_ids - @staticmethod def _pop_req_id(req_ids: list, seq_id) -> bool: candidates = (seq_id, str(seq_id)) @@ -1393,44 +1373,6 @@ def _update_waiting_for_remote_kv(self, seq: Sequence) -> bool: ) return True - def _promote_ready_remote_kv_requests(self) -> None: - """Move completed remote-KV waiters ahead of fresh admissions. - - Offload/remote-KV waiters already own their allocated block table. If a - later fresh request reaches the head of ``waiting`` while HBM is full, - ``can_allocate()`` breaks the admission loop before the scheduler can - inspect and wake completed remote waiters behind it. The completed - waiters then cannot finish and free blocks, so the fresh request also - cannot allocate. Keep normal FIFO order otherwise. - """ - if not self.waiting or not ( - self.finished_recving_kv_req_ids or self.failed_recving_kv_req_ids - ): - return - - ready: deque[Sequence] = deque() - blocked: deque[Sequence] = deque() - while self.waiting: - seq = self.waiting.popleft() - if seq.status == SequenceStatus.WAITING_FOR_REMOTE_KVS and ( - self._has_req_id(self.finished_recving_kv_req_ids, seq.id) - or self._has_req_id(self.failed_recving_kv_req_ids, seq.id) - ): - ready.append(seq) - else: - blocked.append(seq) - - if ready: - offload_trace( - "scheduler_promote_remote_ready", - reqs=[seq.id for seq in ready], - rest=len(blocked), - ) - self.waiting.extend(ready) - self.waiting.extend(blocked) - else: - self.waiting.extend(blocked) - def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): """Reconcile scheduler state with completed KV transfers. diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index c2d4d000f9..85b0507dce 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -158,70 +158,6 @@ def test_segment_indexed_stitches_chunk_buffers(monkeypatch): assert torch.equal(actual, expected) -@pytest.mark.parametrize("layout", ["block", "segment", "segment_indexed"]) -def test_codec_h2d_block_subrange(monkeypatch, layout): - import torch - if not hasattr(torch, "arange"): - pytest.skip("real torch is unavailable") - - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", layout) - original = { - "l0": SimpleNamespace( - k_cache=torch.arange(8 * 2 * 3, dtype=torch.uint8).reshape(8, 2, 3), - v_cache=(torch.arange(8 * 4, dtype=torch.uint8).reshape(8, 4) + 51), - k_scale=torch.arange(8, dtype=torch.uint8).reshape(8, 1) + 101, - v_scale=torch.arange(8, dtype=torch.uint8).reshape(8, 1) + 151, - ), - "l1": SimpleNamespace( - k_cache=(torch.arange(8 * 3, dtype=torch.uint8).reshape(8, 3) + 201), - v_cache=(torch.arange(8 * 2, dtype=torch.uint8).reshape(8, 2) + 31), - k_scale=None, - v_scale=None, - ), - } - kv_caches = { - name: SimpleNamespace( - k_cache=layer.k_cache.clone(), - v_cache=layer.v_cache.clone(), - k_scale=layer.k_scale.clone() if layer.k_scale is not None else None, - v_scale=layer.v_scale.clone() if layer.v_scale is not None else None, - ) - for name, layer in original.items() - } - codec = ATOMKVByteCodec(kv_caches) - source_ids = [0, 1, 2, 3] - host = torch.empty(len(source_ids) * codec.bytes_per_block, dtype=torch.uint8) - codec.gpu_to_host(host, source_ids) - - for layer in kv_caches.values(): - layer.k_cache.zero_() - layer.v_cache.zero_() - if layer.k_scale is not None: - layer.k_scale.zero_() - if layer.v_scale is not None: - layer.v_scale.zero_() - - codec.host_to_gpu_block_range( - host, - src_block_start=1, - block_ids=[5, 7], - src_block_count=len(source_ids), - ) - - for name, layer in kv_caches.items(): - src = original[name] - assert torch.equal(layer.k_cache[5], src.k_cache[1]) - assert torch.equal(layer.v_cache[5], src.v_cache[1]) - assert torch.equal(layer.k_cache[7], src.k_cache[2]) - assert torch.equal(layer.v_cache[7], src.v_cache[2]) - assert torch.count_nonzero(layer.k_cache[0]).item() == 0 - if layer.k_scale is not None: - assert torch.equal(layer.k_scale[5], src.k_scale[1]) - assert torch.equal(layer.v_scale[5], src.v_scale[1]) - assert torch.equal(layer.k_scale[7], src.k_scale[2]) - assert torch.equal(layer.v_scale[7], src.v_scale[2]) - - @pytest.mark.parametrize("layout", ["block", "segment", "segment_indexed"]) @pytest.mark.parametrize("method_name", ["gpu_to_host", "host_to_gpu"]) def test_codec_rejects_invalid_block_ids_before_copy(monkeypatch, layout, method_name): @@ -349,86 +285,29 @@ def test_load_is_skipped_if_hbm_satisfies_after_allocation(): assert should_park is True # Prefix-cache allocation can discover a larger HBM hit than the lookup-time - # snapshot. In that case the scheduler should not park the request or emit a - # worker no-op; it can continue prefill locally from the HBM floor. + # snapshot. In that case the scheduler still emits a no-op load so the + # normal worker aggregation path can wake the parked seq. seq.num_cached_tokens = 8 sched.update_state_after_alloc(seq) - assert sched.should_park_for_load_after_alloc(seq) is False - meta = sched.build_connector_meta() - - assert [r for r in meta.requests if r.load_spec is not None] == [] - assert seq.offload_loaded_tokens == 8 - assert lookup.cleared == ["321"] - - -def test_load_is_skipped_if_hbm_floor_is_not_block_aligned(): - sched = _scheduler() - lookup = _LookupClient(hit=12) - sched._lookup_client = lookup - seq = SimpleNamespace( - id=654, - num_prompt_tokens=16, - token_ids=list(range(16)), - num_cached_tokens=0, - block_table=[1, 2, 3, 4], - ) - - need, should_park = sched.get_num_new_matched_tokens(seq) - assert need == 12 - assert should_park is True - - # A floor inside an ATOM block cannot be restored with whole-block H2D - # copies, so the scheduler keeps the seq local and lets suffix prefill - # continue from the HBM floor. - seq.num_cached_tokens = 6 - sched.update_state_after_alloc(seq) - assert sched.should_park_for_load_after_alloc(seq) is False - meta = sched.build_connector_meta() - - assert [r for r in meta.requests if r.load_spec is not None] == [] - assert seq.offload_loaded_tokens == 6 - - -def test_block_aligned_partial_hbm_chunk_still_parks_after_allocation(): - sched = _scheduler() - sched.chunk_size = 8 - sched.block_size = 4 - lookup = _LookupClient(hit=16) - sched._lookup_client = lookup - seq = SimpleNamespace( - id=655, - num_prompt_tokens=24, - token_ids=list(range(24)), - num_cached_tokens=0, - block_table=[1, 2, 3, 4, 5, 6], - ) - - need, should_park = sched.get_num_new_matched_tokens(seq) - assert need == 16 - assert should_park is True - - seq.num_cached_tokens = 12 - sched.update_state_after_alloc(seq) - assert sched.should_park_for_load_after_alloc(seq) is True meta = sched.build_connector_meta() assert len(meta.requests) == 1 req = meta.requests[0] - assert req.req_id == 655 - assert req.token_ids == list(range(16)) - assert req.block_ids == [1, 2, 3, 4, 5, 6] - assert req.load_spec.hbm_cached_tokens == 12 - assert req.load_spec.lmcache_cached_tokens == 16 - assert seq.offload_loaded_tokens == 16 + assert req.req_id == 321 + assert req.token_ids == list(range(8)) + assert req.block_ids == [1, 2, 3] + assert req.load_spec.hbm_cached_tokens == 8 + assert req.load_spec.lmcache_cached_tokens == 8 + assert seq.offload_loaded_tokens == 8 assert lookup.cleared == [] -def test_real_aligned_load_still_parks_after_allocation(): +def test_load_is_skipped_if_hbm_floor_is_not_chunk_aligned(): sched = _scheduler() lookup = _LookupClient(hit=12) sched._lookup_client = lookup seq = SimpleNamespace( - id=777, + id=654, num_prompt_tokens=16, token_ids=list(range(16)), num_cached_tokens=0, @@ -439,19 +318,22 @@ def test_real_aligned_load_still_parks_after_allocation(): assert need == 12 assert should_park is True - seq.num_cached_tokens = 4 + # HBM prefix cache can return block-size granularity, while LMCache chunks + # are larger. Loading from a non-chunk boundary would either overlap shared + # HBM blocks or leave a gap, so the scheduler wakes the seq with a no-op + # load and lets suffix prefill continue from the HBM floor. + seq.num_cached_tokens = 6 sched.update_state_after_alloc(seq) - assert sched.should_park_for_load_after_alloc(seq) is True meta = sched.build_connector_meta() assert len(meta.requests) == 1 req = meta.requests[0] - assert req.req_id == 777 - assert req.token_ids == list(range(12)) + assert req.req_id == 654 + assert req.token_ids == list(range(6)) assert req.block_ids == [1, 2, 3, 4] - assert req.load_spec.hbm_cached_tokens == 4 - assert req.load_spec.lmcache_cached_tokens == 12 - assert lookup.cleared == [] + assert req.load_spec.hbm_cached_tokens == 6 + assert req.load_spec.lmcache_cached_tokens == 6 + assert seq.offload_loaded_tokens == 6 def test_worker_completes_noop_load_when_hbm_satisfies(): @@ -477,14 +359,13 @@ def test_worker_completes_noop_load_when_hbm_satisfies(): assert conn._engine.unpinned == ["321"] -def test_worker_reports_non_block_aligned_hbm_load_as_failed_without_exception(): +def test_worker_reports_unaligned_hbm_load_as_failed_without_exception(): conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) conn._lock = threading.Lock() conn._done_load = set() conn._failed_load = set() conn._done_save = set() conn.chunk_size = 4 - conn.block_size = 4 conn._engine = SimpleNamespace(unpinned=[]) conn._engine.lookup_unpin = lambda ids: conn._engine.unpinned.extend(ids) @@ -502,105 +383,6 @@ def test_worker_reports_non_block_aligned_hbm_load_as_failed_without_exception() assert conn._engine.unpinned == ["654"] -def test_worker_partial_hbm_chunk_copies_only_missing_tail_blocks(): - import torch - if not hasattr(torch, "ones"): - pytest.skip("real torch is unavailable") - - conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) - conn._lock = threading.Lock() - conn._done_load = set() - conn._failed_load = set() - conn._done_save = set() - conn.chunk_size = 8 - conn.block_size = 4 - conn._request_fastpath = False - conn._rank = 0 - conn._engine = SimpleNamespace(unpinned=[]) - conn._engine.lookup_unpin = lambda ids: conn._engine.unpinned.extend(ids) - - class _Stream: - def __init__(self): - self.syncs = 0 - - def synchronize(self): - self.syncs += 1 - - stream = _Stream() - conn._stream = lambda: stream - - masks = [] - - class _TDB: - def process_tokens(self, token_ids, mask=None): - masks.append(mask.clone()) - return [(8, 16, "k1"), (16, 24, "k2")] - - class _MemoryObj: - def __init__(self, key): - self.key = key - self.tensor = torch.empty(2, dtype=torch.uint8) - self.refs = 1 - - def ref_count_down(self): - self.refs -= 1 - - class _SM: - def __init__(self): - self.objs = {"k1": _MemoryObj("k1"), "k2": _MemoryObj("k2")} - - def contains(self, key): - return key in self.objs - - def get(self, key): - return self.objs[key] - - copy_calls = [] - - class _Codec: - layout = "segment_indexed" - bytes_per_block = 1 - - def copy_calls_for_block_ids(self, block_ids): - return len(block_ids) - - def host_to_gpu_block_range( - self, - tensor, - src_block_start, - block_ids, - stream_arg=None, - src_block_count=None, - ): - copy_calls.append( - (src_block_start, list(block_ids), stream_arg, src_block_count) - ) - - conn._tdb = _TDB() - conn._sm = _SM() - conn._codec = _Codec() - - req = SimpleNamespace( - req_id=655, - token_ids=list(range(24)), - block_ids=[100, 101, 102, 103, 104, 105], - load_spec=SimpleNamespace(hbm_cached_tokens=12, lmcache_cached_tokens=24), - ) - - conn._do_load_req(req) - - assert masks - assert masks[0][:8].any().item() is False - assert masks[0][8:].all().item() is True - assert copy_calls == [ - (1, [103], stream, 2), - (0, [104, 105], stream, 2), - ] - assert conn._done_load == {655} - assert conn._failed_load == set() - assert conn._engine.unpinned == ["655"] - - def test_load_exception_is_reported_as_failed_recving(): conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) conn._lock = threading.Lock() diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 3751fe6e0e..e53d84b935 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -141,53 +141,6 @@ def test_prefill_respects_block_availability(self, seq_factory): batch, _ = sched.schedule() assert batch.total_seqs_num_prefill == 1 - def test_ready_remote_kv_waiter_bypasses_blocked_fresh_request(self, seq_factory): - class _OffloadConnector: - is_offload = True - is_producer = False - - def get_num_new_matched_tokens(self, seq): - return 0, False - - def build_connector_meta(self): - return None - - sched = Scheduler( - MockConfig( - num_kvcache_blocks=4, - kv_cache_block_size=4, - max_model_len=64, - max_num_batched_tokens=64, - ) - ) - sched.kv_connector = _OffloadConnector() - - blocked = seq_factory(list(range(12)), block_size=4) - remote = seq_factory(list(range(100, 116)), block_size=4) - remote.status = SequenceStatus.WAITING_FOR_REMOTE_KVS - remote.block_table = [0, 1, 2, 3] - remote.num_cached_tokens = 4 - remote.offload_loaded_tokens = 12 - - for block_id in remote.block_table: - block = sched.block_manager.blocks[block_id] - block.ref_count = 1 - sched.block_manager.used_block_ids = set(remote.block_table) - sched.block_manager.free_block_ids.clear() - sched.block_manager.free_block_ids_set.clear() - - sched.add(blocked) - sched.add(remote) - sched.finished_recving_kv_req_ids = [remote.id] - - batch, seqs = sched.schedule() - - assert remote.id in seqs - assert blocked.status == SequenceStatus.WAITING - assert remote.status == SequenceStatus.RUNNING - assert remote.num_cached_tokens == 12 - assert list(batch.scheduled_tokens) == remote.token_ids[12:16] - def test_decode_after_prefill(self, scheduler, seq_factory): seq = seq_factory([1, 2, 3, 4]) scheduler.add(seq) From 895ecdb580b37bb2152bb1cec3faeedcf6a17c6b Mon Sep 17 00:00:00 2001 From: yihonglie Date: Mon, 1 Jun 2026 21:58:23 -0500 Subject: [PATCH 06/27] Fix LMCache offload reload handoff Apply Scheme A load semantics after allocation: skip small or unsafe CPU reloads, hand off unaligned HBM floors by prefill-to-boundary when enabled, and promote completed remote-KV waiters before fresh admissions. Add scheduler and connector tests for aligned reloads, unaligned handoff, min-load skips, and deferred-output cleanup. --- atom/kv_transfer/offload/connector.py | 321 ++++++++++++++++++------ atom/model_engine/scheduler.py | 99 +++++++- tests/test_lmcache_offload_connector.py | 166 ++++++++++-- tests/test_scheduler.py | 80 ++++++ 4 files changed, 563 insertions(+), 103 deletions(-) diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index 658c320fef..f805724de0 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -910,6 +910,21 @@ def __init__(self, config) -> None: self._save_tracker: dict[str, list] = {} self._save_inflight: set[str] = set() self._lookup_in_step: list[str] = [] + self._handoff_loads: set[str] = set() + self._allow_unaligned_handoff = os.environ.get( + "OFFLOAD_UNALIGNED_HANDOFF", "0" + ).lower() in ("1", "true", "yes", "on") + try: + self._min_load_tokens = max( + 0, int(os.environ.get("OFFLOAD_MIN_LOAD_TOKENS", "8192")) + ) + except ValueError: + logger.warning( + "LMCache offload scheduler: invalid OFFLOAD_MIN_LOAD_TOKENS=%r; " + "using 8192", + os.environ.get("OFFLOAD_MIN_LOAD_TOKENS"), + ) + self._min_load_tokens = 8192 try: cfg = offcfg.build_lmcache_config() @@ -1024,14 +1039,207 @@ def update_state_after_alloc(self, seq) -> None: if sid not in self._save_tracker: self._save_tracker[sid] = [seq, 0] + def _clear_pending_load(self, sid: str) -> None: + self._load_specs.pop(sid, None) + self._reqs_need_recv.pop(sid, None) + self._handoff_loads.discard(sid) + self._lookup_in_step = [ + req_id for req_id in self._lookup_in_step if req_id != sid + ] + if self._lookup_client is not None: + try: + self._lookup_client.clear_lookup_status(sid) + except Exception: + pass + + def _decide_load_after_alloc( + self, seq, ls: LoadSpec + ) -> tuple[bool, str, int, int, int, int]: + hbm = int(getattr(seq, "num_cached_tokens", ls.hbm_cached_tokens)) + lmc = int(ls.lmcache_cached_tokens) + ls.hbm_cached_tokens = hbm + chunk = int(self.chunk_size or 256) + need = lmc - hbm + if lmc <= hbm: + return False, "hbm_satisfies_after_alloc", hbm, lmc, need, chunk + if hbm % chunk != 0: + return False, "unaligned_hbm_prefill", hbm, lmc, need, chunk + min_load = int(getattr(self, "_min_load_tokens", 8192)) + if need < min_load: + return False, "too_small", hbm, lmc, need, chunk + return True, "aligned_large_hit", hbm, lmc, need, chunk + + def _maybe_start_unaligned_handoff( + self, + seq, + ls: LoadSpec, + hbm: int, + lmc: int, + chunk: int, + ) -> bool: + if not getattr(self, "_allow_unaligned_handoff", False): + return False + boundary = ((hbm + chunk - 1) // chunk) * chunk + remaining_after_boundary = lmc - boundary + min_load = int(getattr(self, "_min_load_tokens", 8192)) + if boundary <= hbm or remaining_after_boundary < min_load: + return False + + sid = str(seq.id) + ls.hbm_cached_tokens = boundary + ls.can_load = True + self._reqs_need_recv.pop(sid, None) + self._handoff_loads.add(sid) + seq.offload_loaded_tokens = hbm + seq.offload_handoff_boundary_tokens = boundary + logger.info( + "[OFFLOAD-LOAD-HANDOFF] seq=%s hbm_cached=%d boundary=%d " + "lmc_cached=%d need_after_boundary=%d min_load=%d chunk=%d", + seq.id, + hbm, + boundary, + lmc, + remaining_after_boundary, + min_load, + chunk, + ) + offload_trace( + "scheduler_load_handoff_start", + req=seq.id, + hbm=hbm, + boundary=boundary, + lmc=lmc, + need_after_boundary=remaining_after_boundary, + min_load=min_load, + chunk=chunk, + blocks=len(list(seq.block_table)), + ) + return True + + def adjust_prefill_chunk_after_alloc(self, seq, chunk: int) -> int: + sid = str(seq.id) + if sid not in self._handoff_loads: + return chunk + boundary = getattr(seq, "offload_handoff_boundary_tokens", None) + if boundary is None: + return chunk + hbm = int(getattr(seq, "num_cached_tokens", 0)) + limit = int(boundary) - hbm + if limit <= 0: + return chunk + adjusted = min(int(chunk), limit) + offload_trace( + "scheduler_load_handoff_prefill_boundary", + req=seq.id, + hbm=hbm, + boundary=int(boundary), + original_chunk=int(chunk), + adjusted_chunk=adjusted, + ) + return max(1, adjusted) + + def should_park_partial_prefill_for_load(self, seq) -> bool: + sid = str(seq.id) + if sid not in self._handoff_loads: + return False + ls = self._load_specs.get(sid) + if ls is None: + self._handoff_loads.discard(sid) + return False + boundary = int(getattr(seq, "offload_handoff_boundary_tokens", 0) or 0) + hbm = int(getattr(seq, "num_cached_tokens", 0)) + if boundary > 0 and hbm < boundary: + return False + + should_load, reason, hbm, lmc, need, chunk = self._decide_load_after_alloc( + seq, ls + ) + if not should_load: + self._mark_load_skip(seq, reason, hbm, lmc, need, chunk) + self._clear_pending_load(sid) + return False + + ls.can_load = True + self._reqs_need_recv[sid] = seq + self._handoff_loads.discard(sid) + seq.offload_loaded_tokens = max(hbm, lmc) + logger.info( + "[OFFLOAD-LOAD-HANDOFF-READY] seq=%s hbm_cached=%d " + "lmc_cached=%d offload_loaded=%d need=%d", + seq.id, + hbm, + lmc, + seq.offload_loaded_tokens, + need, + ) + offload_trace( + "scheduler_load_handoff_ready", + req=seq.id, + hbm=hbm, + lmc=lmc, + need=need, + blocks=len(list(seq.block_table)), + ) + return True + + def _mark_load_skip( + self, + seq, + reason: str, + hbm: int, + lmc: int, + need: int, + chunk: int, + ) -> None: + seq.offload_loaded_tokens = hbm + min_load = int(getattr(self, "_min_load_tokens", 8192)) + logger.info( + "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " + "need=%d min_load=%d chunk=%d reason=%s", + seq.id, + hbm, + lmc, + need, + min_load, + chunk, + reason, + ) + offload_trace( + "scheduler_load_skip", + req=seq.id, + reason=reason, + hbm=hbm, + lmc=lmc, + need=need, + min_load=min_load, + chunk=chunk, + blocks=len(list(seq.block_table)), + ) + + def should_park_for_load_after_alloc(self, seq) -> bool: + sid = str(seq.id) + ls = self._load_specs.get(sid) + if ls is None: + return False + should_load, reason, hbm, lmc, need, chunk = self._decide_load_after_alloc(seq, ls) + if not should_load: + if reason == "unaligned_hbm_prefill" and self._maybe_start_unaligned_handoff( + seq, ls, hbm, lmc, chunk + ): + return False + self._mark_load_skip(seq, reason, hbm, lmc, need, chunk) + self._clear_pending_load(sid) + return False + seq.offload_loaded_tokens = max(hbm, lmc) + return True + def build_connector_meta(self) -> LMCacheOffloadMetadata: meta = LMCacheOffloadMetadata() - meta.lookup_requests_in_step = self._lookup_in_step - self._lookup_in_step = [] # Loads logger.debug("[OFFLOAD-BUILD] reqs_need_recv=%d", len(self._reqs_need_recv)) - for sid, seq in self._reqs_need_recv.items(): + loading_sids: set[str] = set() + for sid, seq in list(self._reqs_need_recv.items()): ls = self._load_specs.pop(sid, None) if ls is None or not ls.can_load: logger.debug("[OFFLOAD-LOAD-SKIP] seq=%s ls=%s can_load=%s", @@ -1044,99 +1252,53 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: # true HBM hit. Loading below this floor would overwrite HBM # prefix-cache blocks (possibly shared with other seqs) -> output # corruption. So load only [hbm_cached, offload_hit). - ls.hbm_cached_tokens = int(seq.num_cached_tokens) - if ls.hbm_cached_tokens >= int(ls.lmcache_cached_tokens): - seq.offload_loaded_tokens = int(seq.num_cached_tokens) - logger.info( - "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " - "reason=hbm_satisfies_after_alloc", - seq.id, - ls.hbm_cached_tokens, - ls.lmcache_cached_tokens, - ) - offload_trace( - "scheduler_load_hbm_satisfies_after_alloc", - req=seq.id, - hbm=ls.hbm_cached_tokens, - lmc=ls.lmcache_cached_tokens, - blocks=len(list(seq.block_table)), - ) - # The request may already be parked in WAITING_FOR_REMOTE_KVS. - # Emit a no-op load so every worker reports finished_recving via - # the normal aggregation path instead of trying to complete it - # locally in the scheduler process. - meta.add_request(LMCacheReqMeta( - req_id=seq.id, - token_ids=list(seq.token_ids[: ls.lmcache_cached_tokens]), - block_ids=list(seq.block_table), - load_spec=ls, - )) - continue - chunk = self.chunk_size or 256 - if ls.hbm_cached_tokens % chunk != 0: - seq.offload_loaded_tokens = int(seq.num_cached_tokens) - logger.info( - "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " - "reason=unaligned_hbm chunk=%d", - seq.id, - ls.hbm_cached_tokens, - ls.lmcache_cached_tokens, - chunk, - ) - offload_trace( - "scheduler_load_unaligned_hbm", - req=seq.id, - hbm=ls.hbm_cached_tokens, - lmc=ls.lmcache_cached_tokens, - chunk=chunk, - blocks=len(list(seq.block_table)), - ) - # LMCache chunks can only be loaded from a chunk boundary. Do - # not round down and overwrite HBM prefix-cache blocks that may - # be shared with other requests; wake the parked request and let - # it continue prefill from the HBM floor. - meta.add_request(LMCacheReqMeta( - req_id=seq.id, - token_ids=list(seq.token_ids[: ls.hbm_cached_tokens]), - block_ids=list(seq.block_table), - load_spec=LoadSpec( - hbm_cached_tokens=ls.hbm_cached_tokens, - lmcache_cached_tokens=ls.hbm_cached_tokens, - can_load=True, - ), - )) + should_load, reason, hbm, lmc, need, chunk = self._decide_load_after_alloc(seq, ls) + if not should_load: + self._mark_load_skip(seq, reason, hbm, lmc, need, chunk) + self._clear_pending_load(sid) continue # num_cached after load = max(HBM, offload); never drop below HBM. - seq.offload_loaded_tokens = max( - int(seq.num_cached_tokens), int(ls.lmcache_cached_tokens) - ) + seq.offload_loaded_tokens = max(hbm, lmc) # req_id MUST be the raw seq.id (the type the scheduler compares # against in _update_waiting_for_remote_kv); str(seq.id) is only for # LMCache's lookup/pin API. A str here silently never wakes the seq. - logger.info("[OFFLOAD-LOAD-EMIT] seq=%s hbm_cached=%d lmc_cached=%d offload_loaded=%d nblocks=%d", - seq.id, ls.hbm_cached_tokens, ls.lmcache_cached_tokens, - seq.offload_loaded_tokens, len(list(seq.block_table))) + logger.info( + "[OFFLOAD-LOAD-EMIT] seq=%s hbm_cached=%d lmc_cached=%d " + "offload_loaded=%d need=%d min_load=%d nblocks=%d reason=aligned_large_hit", + seq.id, + hbm, + lmc, + seq.offload_loaded_tokens, + need, + int(getattr(self, "_min_load_tokens", 8192)), + len(list(seq.block_table)), + ) offload_trace( "scheduler_load_emit", req=seq.id, - hbm=ls.hbm_cached_tokens, - lmc=ls.lmcache_cached_tokens, + hbm=hbm, + lmc=lmc, + need=need, + min_load=int(getattr(self, "_min_load_tokens", 8192)), offload_loaded=seq.offload_loaded_tokens, blocks=len(list(seq.block_table)), ) + loading_sids.add(sid) meta.add_request(LMCacheReqMeta( req_id=seq.id, - token_ids=list(seq.token_ids[: ls.lmcache_cached_tokens]), + token_ids=list(seq.token_ids[: lmc]), block_ids=list(seq.block_table), load_spec=ls, )) + meta.lookup_requests_in_step = self._lookup_in_step + self._lookup_in_step = [] # Saves: store fully computed prompt chunks. Under scheduler-side # chunked prefill, seq.num_cached_tokens advances after each prefill # chunk's forward has completed; use it as the D2H-safe frontier. chunk = self.chunk_size or 256 for sid, entry in self._save_tracker.items(): seq, saved = entry - if sid in self._reqs_need_recv: + if sid in self._reqs_need_recv or sid in loading_sids: continue # loading this step; defer its save if sid in self._save_inflight: continue # keep at most one save per request in flight @@ -1200,17 +1362,10 @@ def save_finished(self, req_id) -> None: self._save_inflight.discard(str(req_id)) def load_failed(self, req_id) -> None: - self._load_specs.pop(str(req_id), None) - self._reqs_need_recv.pop(str(req_id), None) + self._clear_pending_load(str(req_id)) def request_finished(self, seq) -> None: sid = str(seq.id) - self._load_specs.pop(sid, None) - self._reqs_need_recv.pop(sid, None) + self._clear_pending_load(sid) if not self.should_defer_free(seq): self._save_tracker.pop(sid, None) - if self._lookup_client is not None: - try: - self._lookup_client.clear_lookup_status(sid) - except Exception: - pass diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 991a108dd6..c15d2701d1 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -650,6 +650,9 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: num_scheduled_tokens: list[int] = [] scheduled_spec_decode_tokens: dict[int, np.ndarray] = {} + self._promote_ready_remote_kv_requests() + self._park_ready_offload_partial_prefills() + # ─── Cross-DP prefill alignment (PrefillDelayer) ─────────────── _delayer_allows_prefill = True if self.prefill_delayer is not None: @@ -855,6 +858,12 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: if self.kv_connector is not None: self.kv_connector.update_state_after_alloc(seq) + if need_to_remove_to_load_kv_async_queue: + if hasattr(self.kv_connector, "should_park_for_load_after_alloc"): + need_to_remove_to_load_kv_async_queue = ( + self.kv_connector.should_park_for_load_after_alloc(seq) + ) + if need_to_remove_to_load_kv_async_queue: offload_trace( "scheduler_park_for_load", @@ -867,6 +876,13 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: seq.status = SequenceStatus.WAITING_FOR_REMOTE_KVS continue + if self.kv_connector is not None and hasattr( + self.kv_connector, "adjust_prefill_chunk_after_alloc" + ): + chunk = self.kv_connector.adjust_prefill_chunk_after_alloc( + seq, chunk + ) + assert chunk > 0, ( f"chunk must be positive: {chunk=}, " f"{num_new_tokens=}, {budget_remaining=}" @@ -1100,7 +1116,11 @@ def postprocess( # Deferred output from a previous partial prefill step is garbage # under deferred-out: drop it once, then let the next step's real # first completion token populate the placeholder. - if seq.id in prev_partial_ids: + discard_deferred_output = False + if is_deferred_out and getattr(seq, "_discard_next_deferred_output", False): + seq._discard_next_deferred_output = False + discard_deferred_output = True + if seq.id in prev_partial_ids or discard_deferred_output: continue # Register prefix-cache hashes for blocks the prefill step just # finalized. Deferred from BlockManager.allocate() so a hash is @@ -1337,6 +1357,18 @@ def _is_offload_connector(self) -> bool: """ return getattr(self.kv_connector, "is_offload", False) + @staticmethod + def _has_req_id(req_ids: list, seq_id) -> bool: + candidates = (seq_id, str(seq_id)) + for candidate in candidates: + if candidate in req_ids: + return True + try: + int_id = int(seq_id) + except (TypeError, ValueError): + return False + return int_id in req_ids + @staticmethod def _pop_req_id(req_ids: list, seq_id) -> bool: candidates = (seq_id, str(seq_id)) @@ -1373,6 +1405,71 @@ def _update_waiting_for_remote_kv(self, seq: Sequence) -> bool: ) return True + def _promote_ready_remote_kv_requests(self) -> None: + """Move completed remote-KV waiters ahead of fresh admissions. + + Offload waiters already own allocated blocks. If a fresh request at the + head cannot allocate while a completed waiter sits behind it, the waiter + cannot finish and free blocks. Preserve FIFO order within the ready and + blocked groups. + """ + if not self.waiting or not ( + self.finished_recving_kv_req_ids or self.failed_recving_kv_req_ids + ): + return + + ready: deque[Sequence] = deque() + blocked: deque[Sequence] = deque() + while self.waiting: + seq = self.waiting.popleft() + if seq.status == SequenceStatus.WAITING_FOR_REMOTE_KVS and ( + self._has_req_id(self.finished_recving_kv_req_ids, seq.id) + or self._has_req_id(self.failed_recving_kv_req_ids, seq.id) + ): + ready.append(seq) + else: + blocked.append(seq) + + if ready: + offload_trace( + "scheduler_promote_remote_ready", + reqs=[seq.id for seq in ready], + rest=len(blocked), + ) + self.waiting.extend(ready) + self.waiting.extend(blocked) + else: + self.waiting.extend(blocked) + + def _park_ready_offload_partial_prefills(self) -> None: + if not self.running or self.kv_connector is None or not hasattr( + self.kv_connector, "should_park_partial_prefill_for_load" + ): + return + + ready: deque[Sequence] = deque() + keep_running: deque[Sequence] = deque() + while self.running: + seq = self.running.popleft() + should_park = self.kv_connector.should_park_partial_prefill_for_load(seq) + if should_park: + if seq.is_partial_prefill: + seq._discard_next_deferred_output = True + seq.is_partial_prefill = False + self._partial_prefill_count -= 1 + seq.status = SequenceStatus.WAITING_FOR_REMOTE_KVS + ready.append(seq) + else: + keep_running.append(seq) + + self.running = keep_running + if ready: + offload_trace( + "scheduler_park_partial_prefill_for_load", + reqs=[seq.id for seq in ready], + ) + self.waiting.extendleft(reversed(ready)) + def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): """Reconcile scheduler state with completed KV transfers. diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index 85b0507dce..db738745df 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -49,6 +49,9 @@ def _scheduler() -> LMCacheOffloadConnectorScheduler: sched._save_tracker = {} sched._save_inflight = set() sched._lookup_in_step = [] + sched._handoff_loads = set() + sched._allow_unaligned_handoff = False + sched._min_load_tokens = 0 sched._lock = threading.Lock() sched._done_load = set() return sched @@ -285,21 +288,18 @@ def test_load_is_skipped_if_hbm_satisfies_after_allocation(): assert should_park is True # Prefix-cache allocation can discover a larger HBM hit than the lookup-time - # snapshot. In that case the scheduler still emits a no-op load so the - # normal worker aggregation path can wake the parked seq. + # snapshot. Scheme A should skip the CPU load before parking instead of + # emitting a no-op load. seq.num_cached_tokens = 8 sched.update_state_after_alloc(seq) + assert sched.should_park_for_load_after_alloc(seq) is False meta = sched.build_connector_meta() - assert len(meta.requests) == 1 - req = meta.requests[0] - assert req.req_id == 321 - assert req.token_ids == list(range(8)) - assert req.block_ids == [1, 2, 3] - assert req.load_spec.hbm_cached_tokens == 8 - assert req.load_spec.lmcache_cached_tokens == 8 + assert [req for req in meta.requests if req.load_spec is not None] == [] assert seq.offload_loaded_tokens == 8 - assert lookup.cleared == [] + assert lookup.cleared == ["321"] + assert str(seq.id) not in sched._load_specs + assert str(seq.id) not in sched._reqs_need_recv def test_load_is_skipped_if_hbm_floor_is_not_chunk_aligned(): @@ -320,20 +320,148 @@ def test_load_is_skipped_if_hbm_floor_is_not_chunk_aligned(): # HBM prefix cache can return block-size granularity, while LMCache chunks # are larger. Loading from a non-chunk boundary would either overlap shared - # HBM blocks or leave a gap, so the scheduler wakes the seq with a no-op - # load and lets suffix prefill continue from the HBM floor. + # HBM blocks or leave a gap, so Scheme A skips CPU load and suffix-prefills + # from the HBM floor. seq.num_cached_tokens = 6 sched.update_state_after_alloc(seq) + assert sched.should_park_for_load_after_alloc(seq) is False meta = sched.build_connector_meta() - assert len(meta.requests) == 1 - req = meta.requests[0] - assert req.req_id == 654 - assert req.token_ids == list(range(6)) - assert req.block_ids == [1, 2, 3, 4] - assert req.load_spec.hbm_cached_tokens == 6 - assert req.load_spec.lmcache_cached_tokens == 6 + assert [req for req in meta.requests if req.load_spec is not None] == [] + assert seq.offload_loaded_tokens == 6 + assert lookup.cleared == ["654"] + + +def test_unaligned_hbm_handoff_prefills_boundary_then_emits_load(): + sched = _scheduler() + sched._allow_unaligned_handoff = True + sched._min_load_tokens = 8 + lookup = _LookupClient(hit=16) + sched._lookup_client = lookup + seq = SimpleNamespace( + id=657, + num_prompt_tokens=20, + token_ids=list(range(20)), + num_cached_tokens=0, + block_table=[1, 2, 3, 4, 5], + ) + + need, should_park = sched.get_num_new_matched_tokens(seq) + assert need == 16 + assert should_park is True + + seq.num_cached_tokens = 6 + sched.update_state_after_alloc(seq) + assert sched.should_park_for_load_after_alloc(seq) is False + assert str(seq.id) in sched._handoff_loads + assert seq.offload_handoff_boundary_tokens == 8 assert seq.offload_loaded_tokens == 6 + assert sched.adjust_prefill_chunk_after_alloc(seq, 10) == 2 + + seq.num_cached_tokens = 8 + assert sched.should_park_partial_prefill_for_load(seq) is True + meta = sched.build_connector_meta() + load_reqs = [req for req in meta.requests if req.load_spec is not None] + + assert len(load_reqs) == 1 + req = load_reqs[0] + assert req.req_id == 657 + assert req.token_ids == list(range(16)) + assert req.load_spec.hbm_cached_tokens == 8 + assert req.load_spec.lmcache_cached_tokens == 16 + assert seq.offload_loaded_tokens == 16 + assert str(seq.id) not in sched._handoff_loads + assert lookup.cleared == [] + + +def test_unaligned_handoff_skips_if_boundary_remainder_is_too_small(): + sched = _scheduler() + sched._allow_unaligned_handoff = True + sched._min_load_tokens = 8 + lookup = _LookupClient(hit=12) + sched._lookup_client = lookup + seq = SimpleNamespace( + id=658, + num_prompt_tokens=16, + token_ids=list(range(16)), + num_cached_tokens=0, + block_table=[1, 2, 3, 4], + ) + + need, should_park = sched.get_num_new_matched_tokens(seq) + assert need == 12 + assert should_park is True + + seq.num_cached_tokens = 6 + sched.update_state_after_alloc(seq) + assert sched.should_park_for_load_after_alloc(seq) is False + + assert str(seq.id) not in sched._handoff_loads + assert str(seq.id) not in sched._load_specs + assert str(seq.id) not in sched._reqs_need_recv + assert seq.offload_loaded_tokens == 6 + assert lookup.cleared == ["658"] + + +def test_load_is_skipped_if_aligned_hit_is_below_threshold(): + sched = _scheduler() + sched._min_load_tokens = 8 + lookup = _LookupClient(hit=12) + sched._lookup_client = lookup + seq = SimpleNamespace( + id=655, + num_prompt_tokens=16, + token_ids=list(range(16)), + num_cached_tokens=0, + block_table=[1, 2, 3, 4], + ) + + need, should_park = sched.get_num_new_matched_tokens(seq) + assert need == 12 + assert should_park is True + + seq.num_cached_tokens = 8 + sched.update_state_after_alloc(seq) + assert sched.should_park_for_load_after_alloc(seq) is False + meta = sched.build_connector_meta() + + assert [req for req in meta.requests if req.load_spec is not None] == [] + assert seq.offload_loaded_tokens == 8 + assert lookup.cleared == ["655"] + + +def test_aligned_large_hit_parks_and_emits_load_metadata(): + sched = _scheduler() + sched._min_load_tokens = 8 + lookup = _LookupClient(hit=12) + sched._lookup_client = lookup + seq = SimpleNamespace( + id=656, + num_prompt_tokens=16, + token_ids=list(range(16)), + num_cached_tokens=0, + block_table=[1, 2, 3, 4], + ) + + need, should_park = sched.get_num_new_matched_tokens(seq) + assert need == 12 + assert should_park is True + + seq.num_cached_tokens = 4 + sched.update_state_after_alloc(seq) + assert sched.should_park_for_load_after_alloc(seq) is True + meta = sched.build_connector_meta() + load_reqs = [req for req in meta.requests if req.load_spec is not None] + + assert len(load_reqs) == 1 + req = load_reqs[0] + assert req.req_id == 656 + assert req.token_ids == list(range(12)) + assert req.block_ids == [1, 2, 3, 4] + assert req.load_spec.hbm_cached_tokens == 4 + assert req.load_spec.lmcache_cached_tokens == 12 + assert seq.offload_loaded_tokens == 12 + assert lookup.cleared == [] def test_worker_completes_noop_load_when_hbm_satisfies(): diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index e53d84b935..0df27e66d8 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -2,6 +2,9 @@ # Tests for atom/model_engine/scheduler.py — public API only +from collections import deque +from types import SimpleNamespace + from atom.model_engine.scheduler import Scheduler, ScheduledBatchOutput, SpecStats from atom.model_engine.sequence import SequenceStatus, SequenceType from atom.sampling_params import SamplingParams @@ -166,6 +169,83 @@ def test_decode_preemption(self, seq_factory): assert SequenceStatus.RUNNING in statuses assert SequenceStatus.WAITING in statuses + def test_ready_remote_kv_waiter_is_promoted_ahead_of_fresh_head(self): + sched = Scheduler.__new__(Scheduler) + fresh = SimpleNamespace(id=1, status=SequenceStatus.WAITING) + ready = SimpleNamespace(id=2, status=SequenceStatus.WAITING_FOR_REMOTE_KVS) + blocked = SimpleNamespace(id=3, status=SequenceStatus.WAITING_FOR_REMOTE_KVS) + sched.waiting = deque([fresh, ready, blocked]) + sched.finished_recving_kv_req_ids = ["2"] + sched.failed_recving_kv_req_ids = [] + + sched._promote_ready_remote_kv_requests() + + assert [seq.id for seq in sched.waiting] == [2, 1, 3] + + def test_partial_prefill_ready_for_offload_load_moves_to_waiting(self): + class _Connector: + def should_park_partial_prefill_for_load(self, seq): + return seq.id == 2 + + sched = Scheduler.__new__(Scheduler) + sched.kv_connector = _Connector() + sched.waiting = deque() + sched._partial_prefill_count = 1 + keep = SimpleNamespace( + id=1, + status=SequenceStatus.RUNNING, + is_partial_prefill=False, + ) + ready = SimpleNamespace( + id=2, + status=SequenceStatus.RUNNING, + is_partial_prefill=True, + ) + sched.running = deque([keep, ready]) + + sched._park_ready_offload_partial_prefills() + + assert [seq.id for seq in sched.running] == [1] + assert [seq.id for seq in sched.waiting] == [2] + assert ready.status == SequenceStatus.WAITING_FOR_REMOTE_KVS + assert ready.is_partial_prefill is False + assert ready._discard_next_deferred_output is True + assert sched._partial_prefill_count == 0 + + def test_offload_partial_handoff_discards_stale_deferred_output(self, seq_factory): + sched = Scheduler( + MockConfig( + max_num_batched_tokens=64, + num_kvcache_blocks=10, + kv_cache_block_size=4, + enable_chunked_prefill=True, + ) + ) + seq = seq_factory(list(range(10)), sampling_params=SamplingParams(max_tokens=4)) + seq.status = SequenceStatus.RUNNING + seq.type = SequenceType.PREFILL + seq.num_cached_tokens = 8 + seq._discard_next_deferred_output = True + sched.running = deque([seq]) + + sched.postprocess( + [seq], + ScheduledBatchOutput( + req_ids=[seq.id], + token_ids=[(999,)], + num_rejected=[0], + num_bonus=[0], + draft_token_ids=None, + is_deferred_out=True, + ), + batch=SimpleNamespace(req_ids=[seq.id], num_scheduled_tokens=[2]), + ) + + assert seq.num_cached_tokens == 10 + assert seq._discard_next_deferred_output is False + assert 999 not in seq.output_tokens + assert seq.output_tokens == [sched.eos_token_id] + # ── prefix caching ──────────────────────────────────────────────────────── From 69748041ca6571734a8b5e120888e94e69719ab0 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Tue, 2 Jun 2026 03:09:51 -0500 Subject: [PATCH 07/27] Reduce offload resume scheduler diff noise --- atom/model_engine/scheduler.py | 120 ++++++++++++++++++++------------- 1 file changed, 72 insertions(+), 48 deletions(-) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index c15d2701d1..81b014fee9 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -824,65 +824,89 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: self.waiting.appendleft(seq) break chunk = num_new_tokens - else: - # Probe cache hits FIRST so budget check sees the real - # (post-prefix-cache) remaining token count. - num_cached_blocks = self.block_manager.can_allocate(seq) - if num_cached_blocks < 0: - self.waiting.appendleft(seq) - break - num_new_tokens = ( - seq.num_prompt_tokens - - num_cached_blocks * self.block_manager.block_size + assert chunk > 0, ( + f"chunk must be positive: {chunk=}, " + f"{num_new_tokens=}, {budget_remaining=}" ) - budget_remaining = self.max_num_batched_tokens - num_batched_tokens - if self.enable_chunked_prefill: - chunk = min(num_new_tokens, budget_remaining) - else: - if num_new_tokens > budget_remaining and num_batched_tokens > 0: - self.waiting.appendleft(seq) - break - chunk = num_new_tokens - t_alloc0 = time.perf_counter() - self.block_manager.allocate(seq, num_cached_blocks) + num_seqs_prefill += 1 + if self.cache_stats: + self.cache_stats.update(seq.num_cached_tokens, seq.num_tokens) + num_batched_tokens += chunk + seq.status = SequenceStatus.RUNNING + seq.type = SequenceType.PREFILL + self.running.append(seq) + scheduled_seqs[seq.id] = seq + num_scheduled_tokens.append(chunk) offload_trace( - "scheduler_alloc_done", + "scheduler_prefill_scheduled", req=seq.id, - cached_blocks=num_cached_blocks, - cached_tokens=seq.num_cached_tokens, - blocks=len(seq.block_table), - alloc_ms=f"{(time.perf_counter() - t_alloc0) * 1000:.2f}", + new_tokens=chunk, + cached=seq.num_cached_tokens, + prompt=seq.num_prompt_tokens, + offload_loaded=getattr(seq, "offload_loaded", False), + load_failed=getattr(seq, "offload_load_failed", False), ) + continue - if self.kv_connector is not None: - self.kv_connector.update_state_after_alloc(seq) + # Probe cache hits FIRST so budget check sees the real + # (post-prefix-cache) remaining token count. + num_cached_blocks = self.block_manager.can_allocate(seq) + if num_cached_blocks < 0: + self.waiting.appendleft(seq) + break - if need_to_remove_to_load_kv_async_queue: - if hasattr(self.kv_connector, "should_park_for_load_after_alloc"): - need_to_remove_to_load_kv_async_queue = ( - self.kv_connector.should_park_for_load_after_alloc(seq) - ) + num_new_tokens = ( + seq.num_prompt_tokens + - num_cached_blocks * self.block_manager.block_size + ) + budget_remaining = self.max_num_batched_tokens - num_batched_tokens + if self.enable_chunked_prefill: + chunk = min(num_new_tokens, budget_remaining) + else: + if num_new_tokens > budget_remaining and num_batched_tokens > 0: + self.waiting.appendleft(seq) + break + chunk = num_new_tokens + t_alloc0 = time.perf_counter() + self.block_manager.allocate(seq, num_cached_blocks) + offload_trace( + "scheduler_alloc_done", + req=seq.id, + cached_blocks=num_cached_blocks, + cached_tokens=seq.num_cached_tokens, + blocks=len(seq.block_table), + alloc_ms=f"{(time.perf_counter() - t_alloc0) * 1000:.2f}", + ) - if need_to_remove_to_load_kv_async_queue: - offload_trace( - "scheduler_park_for_load", - req=seq.id, - cached=seq.num_cached_tokens, - prompt=seq.num_prompt_tokens, - blocks=len(seq.block_table), - ) - skipped_waiting_requests.append(seq) - seq.status = SequenceStatus.WAITING_FOR_REMOTE_KVS - continue + if self.kv_connector is not None: + self.kv_connector.update_state_after_alloc(seq) - if self.kv_connector is not None and hasattr( - self.kv_connector, "adjust_prefill_chunk_after_alloc" - ): - chunk = self.kv_connector.adjust_prefill_chunk_after_alloc( - seq, chunk + if need_to_remove_to_load_kv_async_queue: + if hasattr(self.kv_connector, "should_park_for_load_after_alloc"): + need_to_remove_to_load_kv_async_queue = ( + self.kv_connector.should_park_for_load_after_alloc(seq) ) + if need_to_remove_to_load_kv_async_queue: + offload_trace( + "scheduler_park_for_load", + req=seq.id, + cached=seq.num_cached_tokens, + prompt=seq.num_prompt_tokens, + blocks=len(seq.block_table), + ) + skipped_waiting_requests.append(seq) + seq.status = SequenceStatus.WAITING_FOR_REMOTE_KVS + continue + + if self.kv_connector is not None and hasattr( + self.kv_connector, "adjust_prefill_chunk_after_alloc" + ): + chunk = self.kv_connector.adjust_prefill_chunk_after_alloc( + seq, chunk + ) + assert chunk > 0, ( f"chunk must be positive: {chunk=}, " f"{num_new_tokens=}, {budget_remaining=}" From af8ae03e5a35968430b0352faeaeb90689bb2bf3 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Tue, 2 Jun 2026 03:13:18 -0500 Subject: [PATCH 08/27] Format scheduler with Black --- atom/model_engine/scheduler.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 81b014fee9..ccb5156800 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -733,8 +733,8 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: seq.offload_loaded_tokens = seq.num_cached_tokens seq.offload_load_failed = True else: - waiting_remote_to_waiting_ready = self._update_waiting_for_remote_kv( - seq + waiting_remote_to_waiting_ready = ( + self._update_waiting_for_remote_kv(seq) ) if waiting_remote_to_waiting_ready: seq.status = SequenceStatus.WAITING @@ -758,7 +758,10 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: loaded = getattr(seq, "offload_loaded_tokens", None) logger.debug( "[OFFLOAD-WAKE] seq %s: loaded=%s prev_cached=%d num_tokens=%d", - seq.id, loaded, seq.num_cached_tokens, seq.num_tokens, + seq.id, + loaded, + seq.num_cached_tokens, + seq.num_tokens, ) offload_trace( "scheduler_offload_wake", @@ -903,9 +906,7 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: if self.kv_connector is not None and hasattr( self.kv_connector, "adjust_prefill_chunk_after_alloc" ): - chunk = self.kv_connector.adjust_prefill_chunk_after_alloc( - seq, chunk - ) + chunk = self.kv_connector.adjust_prefill_chunk_after_alloc(seq, chunk) assert chunk > 0, ( f"chunk must be positive: {chunk=}, " @@ -1466,8 +1467,10 @@ def _promote_ready_remote_kv_requests(self) -> None: self.waiting.extend(blocked) def _park_ready_offload_partial_prefills(self) -> None: - if not self.running or self.kv_connector is None or not hasattr( - self.kv_connector, "should_park_partial_prefill_for_load" + if ( + not self.running + or self.kv_connector is None + or not hasattr(self.kv_connector, "should_park_partial_prefill_for_load") ): return @@ -1525,7 +1528,9 @@ def _pop_deferred(req_id): assert ( not self.kv_connector.is_producer ), "Only consumer should update failed KV recv status" - logger.warning("KV receive failed for request %s; falling back to prefill.", req_id) + logger.warning( + "KV receive failed for request %s; falling back to prefill.", req_id + ) offload_trace("scheduler_failed_recving", req=req_id) self.failed_recving_kv_req_ids.append(req_id) From e48edfa9d23056cdec7a227c7f0a48ad1f2117a3 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Tue, 2 Jun 2026 03:15:57 -0500 Subject: [PATCH 09/27] Fix offload formatting and lint --- atom/kv_transfer/disaggregation/aggregator.py | 5 +- atom/kv_transfer/offload/config.py | 11 +- atom/kv_transfer/offload/connector.py | 171 ++++++++++++------ atom/kv_transfer/offload/gpu_connector.py | 32 +--- atom/kv_transfer/offload/native_stitch.py | 9 +- tests/test_lmcache_offload_connector.py | 5 + 6 files changed, 150 insertions(+), 83 deletions(-) diff --git a/atom/kv_transfer/disaggregation/aggregator.py b/atom/kv_transfer/disaggregation/aggregator.py index 3fc6c2c227..da2546d132 100644 --- a/atom/kv_transfer/disaggregation/aggregator.py +++ b/atom/kv_transfer/disaggregation/aggregator.py @@ -110,7 +110,10 @@ def aggregate(self, worker_outputs: list[KVConnectorOutput]) -> KVConnectorOutpu for rid in recv_ids: done_workers = self._seen_recving.get(rid, set()) failed_workers = self._seen_recv_failed.get(rid, set()) - if failed_workers and len(done_workers | failed_workers) >= self._world_size: + if ( + failed_workers + and len(done_workers | failed_workers) >= self._world_size + ): failed_recving.add(rid) done_recving = { rid diff --git a/atom/kv_transfer/offload/config.py b/atom/kv_transfer/offload/config.py index 83fe7c2cf1..39b0368be8 100644 --- a/atom/kv_transfer/offload/config.py +++ b/atom/kv_transfer/offload/config.py @@ -13,11 +13,8 @@ from __future__ import annotations -import os from typing import Any -import torch - def build_lmcache_config(): """Return an ``LMCacheEngineConfig`` from ``LMCACHE_*`` env + extras.""" @@ -72,10 +69,14 @@ def build_lmcache_metadata(config, cfg, world_size: int, worker_id: int): hf = config.hf_config num_layers = int(getattr(hf, "num_hidden_layers")) - num_kv_heads = int(getattr(hf, "num_key_value_heads", getattr(hf, "num_attention_heads"))) + num_kv_heads = int( + getattr(hf, "num_key_value_heads", getattr(hf, "num_attention_heads")) + ) tp = int(getattr(config, "tensor_parallel_size", world_size) or 1) num_kv_heads_local = max(1, num_kv_heads // tp) - head_dim = int(getattr(hf, "head_dim", 0) or (hf.hidden_size // hf.num_attention_heads)) + head_dim = int( + getattr(hf, "head_dim", 0) or (hf.hidden_size // hf.num_attention_heads) + ) kv_dtype = dtypes.d_dtypes[config.kv_cache_dtype] model_name = str(getattr(config, "model", "atom-model")) diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index f805724de0..d2daa69bff 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -132,13 +132,19 @@ def register_kv_caches(self, kv_caches: dict, transfer_tensors=None) -> None: self._rank = rank cfg = offcfg.build_lmcache_config() - offcfg.apply_extra_overrides(cfg, getattr(self._config, "kv_transfer_config", None)) + offcfg.apply_extra_overrides( + cfg, getattr(self._config, "kv_transfer_config", None) + ) meta = offcfg.build_lmcache_metadata(self._config, cfg, world, rank) self.chunk_size = int(cfg.chunk_size) self._engine = LMCacheEngineBuilder.get_or_create( - f"atom-offload-{rank}", cfg, meta, _UnusedGPUConnector(), - lambda t, s: None, lambda o, s: o, + f"atom-offload-{rank}", + cfg, + meta, + _UnusedGPUConnector(), + lambda t, s: None, + lambda o, s: o, ) self._engine.post_init() self._sm = self._engine.storage_manager @@ -149,18 +155,26 @@ def register_kv_caches(self, kv_caches: dict, transfer_tensors=None) -> None: # lookup_server makes on behalf of the scheduler) — args + result. _orig_lookup = self._engine.lookup _rk = rank + def _logged_lookup(*a, **k): r = _orig_lookup(*a, **k) h = k.get("hashes") - logger.debug("[ENGINE.LOOKUP] rank=%s lookup_id=%s nhashes=%s first3=%s -> %s", - _rk, k.get("lookup_id"), (len(h) if h is not None else None), - (list(h[:3]) if h else None), r) + logger.debug( + "[ENGINE.LOOKUP] rank=%s lookup_id=%s nhashes=%s first3=%s -> %s", + _rk, + k.get("lookup_id"), + (len(h) if h is not None else None), + (list(h[:3]) if h else None), + r, + ) return r + self._engine.lookup = _logged_lookup # ZMQ lookup server so the scheduler process can query our hit counts. try: from lmcache.v1.lookup_client.factory import LookupClientFactory + self._lookup_server = LookupClientFactory.create_lookup_server( self._engine, meta ) @@ -170,8 +184,12 @@ def _logged_lookup(*a, **k): logger.info( "LMCache offload worker rank=%d: bytes_per_block=%d chunk=%d " "codec_layout=%s save=%s load=%s", - rank, self._codec.bytes_per_block, self.chunk_size, self._codec.layout, - self._do_save, self._do_load, + rank, + self._codec.bytes_per_block, + self.chunk_size, + self._codec.layout, + self._do_save, + self._do_load, ) # -- per-step (RPC thread): only enqueue, never copy ------------------ @@ -210,7 +228,9 @@ def _guard(self, kind: str, fn, req) -> None: try: fn(req) except Exception: - logger.exception("LMCache offload: %s failed for %s", fn.__name__, req.req_id) + logger.exception( + "LMCache offload: %s failed for %s", fn.__name__, req.req_id + ) if kind == "load": self._lookup_unpin(req.req_id) with self._lock: @@ -385,16 +405,22 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: t0 = time.perf_counter() chunks = list(self._tdb.process_tokens(torch.tensor(toks), mask=mask)) process_ms = (time.perf_counter() - t0) * 1000 - logger.debug("offload _do_load req=%s hbm=%d lmc=%d chunks=%d", - req.req_id, hbm, ls.lmcache_cached_tokens, len(chunks)) + logger.debug( + "offload _do_load req=%s hbm=%d lmc=%d chunks=%d", + req.req_id, + hbm, + ls.lmcache_cached_tokens, + len(chunks), + ) # All-or-nothing above the HBM prefix: a partial load would let attention # read uninitialized blocks, and a chunk that overlaps an HBM-cache hit # could overwrite shared prefix-cache blocks. In either case the seq # wakes and re-prefills from its HBM floor. if not chunks: - logger.warning("LMCache offload: no loadable chunks req=%s; re-prefill", - req.req_id) + logger.warning( + "LMCache offload: no loadable chunks req=%s; re-prefill", req.req_id + ) self._lookup_unpin(req.req_id) with self._lock: self._failed_load.add(req.req_id) @@ -406,12 +432,14 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", ) return - for (s, _e, _key) in chunks: + for s, _e, _key in chunks: if s < hbm: logger.warning( "LMCache offload: chunk overlaps HBM prefix req=%s hbm=%d " "chunk_start=%d; re-prefill", - req.req_id, hbm, s, + req.req_id, + hbm, + s, ) self._lookup_unpin(req.req_id) with self._lock: @@ -513,11 +541,13 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: if req_mo is not None: req_mo.ref_count_down() - for (_s, _e, key) in chunks: + for _s, _e, key in chunks: t0 = time.perf_counter() if not self._sm.contains(key): contains_ms += (time.perf_counter() - t0) * 1000 - logger.warning("LMCache offload: load miss req=%s; re-prefill", req.req_id) + logger.warning( + "LMCache offload: load miss req=%s; re-prefill", req.req_id + ) self._lookup_unpin(req.req_id) with self._lock: self._failed_load.add(req.req_id) @@ -533,7 +563,7 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: contains_ms += (time.perf_counter() - t0) * 1000 try: - for (s, e, key) in chunks: + for s, e, key in chunks: t0 = time.perf_counter() mo = self._sm.get(key) get_ms += (time.perf_counter() - t0) * 1000 @@ -701,7 +731,7 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: copy_calls = 0 chunk_bids: list[list[int]] = [] try: - for (s, e, key) in chunks: + for s, e, key in chunks: self._pause_save_for_load(stream) t0 = time.perf_counter() if self._sm.contains(key): # already offloaded → skip wasted D2H @@ -712,8 +742,9 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: bids = self._block_ids(req, s, e) chunk_nbytes = len(bids) * self._codec.bytes_per_block t0 = time.perf_counter() - mo = self._sm.allocate(torch.Size((chunk_nbytes,)), torch.uint8, - fmt=MemoryFormat.KV_2LTD) + mo = self._sm.allocate( + torch.Size((chunk_nbytes,)), torch.uint8, fmt=MemoryFormat.KV_2LTD + ) alloc_ms += (time.perf_counter() - t0) * 1000 if mo is None: # pool under pressure; stop here break @@ -842,10 +873,18 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: if logger.isEnabledFor(logging.DEBUG): _kh = [getattr(k, "chunk_hash", None) for k in keys[:2]] _contains = [bool(self._sm.contains(k)) for k in keys[:2]] - logger.debug("[OFFLOAD-SAVE] rank=%s req=%s toks=%d chunks=%d stored=%d already=%d " - "chunkhash2=%s contains=%s", - self._rank, req.req_id, len(toks), len(chunks), len(keys), - already, _kh, _contains) + logger.debug( + "[OFFLOAD-SAVE] rank=%s req=%s toks=%d chunks=%d stored=%d already=%d " + "chunkhash2=%s contains=%s", + self._rank, + req.req_id, + len(toks), + len(chunks), + len(keys), + already, + _kh, + _contains, + ) # -- per-step (RPC thread, post-forward): poll completions ------------ def get_finished(self) -> KVConnectorOutput: @@ -931,11 +970,14 @@ def __init__(self, config) -> None: offcfg.apply_extra_overrides(cfg, kvc) self.chunk_size = int(cfg.chunk_size) from lmcache.v1.lookup_client.factory import LookupClientFactory + world = int(getattr(config, "tensor_parallel_size", 1) or 1) meta = offcfg.build_lmcache_metadata(config, cfg, world, 0) self._lookup_client = LookupClientFactory.create_lookup_client(cfg, meta) except Exception as e: - logger.warning("LMCache offload scheduler: lookup client unavailable: %s", e) + logger.warning( + "LMCache offload scheduler: lookup client unavailable: %s", e + ) # -- match: how many extra tokens can come from CPU/NVMe ------------- def get_num_new_matched_tokens(self, seq) -> tuple[int, bool]: @@ -967,12 +1009,22 @@ def get_num_new_matched_tokens(self, seq) -> tuple[int, bool]: try: tdb = getattr(self._lookup_client, "token_database", None) if tdb is not None: - _lh = [k for (_s, _e, k) in list( - tdb.process_tokens(token_ids, make_key=False))[:3]] + _lh = [ + k + for (_s, _e, k) in list( + tdb.process_tokens(token_ids, make_key=False) + )[:3] + ] except Exception as e: _lh = f"err:{e}" - logger.debug("[OFFLOAD-LOOKUP] seq=%s num_prompt=%d hbm_cached=%d hit=%s lookuphash3=%s", - seq.id, num_prompt, int(seq.num_cached_tokens), hit, _lh) + logger.debug( + "[OFFLOAD-LOOKUP] seq=%s num_prompt=%d hbm_cached=%d hit=%s lookuphash3=%s", + seq.id, + num_prompt, + int(seq.num_cached_tokens), + hit, + _lh, + ) if not hit: offload_trace( "scheduler_lookup_done", @@ -1022,8 +1074,12 @@ def get_num_new_matched_tokens(self, seq) -> tuple[int, bool]: def update_state_after_alloc(self, seq) -> None: sid = str(seq.id) ls = self._load_specs.get(sid) - logger.debug("[OFFLOAD-ALLOC] seq=%s ls_found=%s num_cached_now=%s", - seq.id, ls is not None, int(getattr(seq, "num_cached_tokens", -1))) + logger.debug( + "[OFFLOAD-ALLOC] seq=%s ls_found=%s num_cached_now=%s", + seq.id, + ls is not None, + int(getattr(seq, "num_cached_tokens", -1)), + ) if ls is not None: ls.can_load = True self._reqs_need_recv[sid] = seq @@ -1221,10 +1277,13 @@ def should_park_for_load_after_alloc(self, seq) -> bool: ls = self._load_specs.get(sid) if ls is None: return False - should_load, reason, hbm, lmc, need, chunk = self._decide_load_after_alloc(seq, ls) + should_load, reason, hbm, lmc, need, chunk = self._decide_load_after_alloc( + seq, ls + ) if not should_load: - if reason == "unaligned_hbm_prefill" and self._maybe_start_unaligned_handoff( - seq, ls, hbm, lmc, chunk + if ( + reason == "unaligned_hbm_prefill" + and self._maybe_start_unaligned_handoff(seq, ls, hbm, lmc, chunk) ): return False self._mark_load_skip(seq, reason, hbm, lmc, need, chunk) @@ -1242,8 +1301,12 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: for sid, seq in list(self._reqs_need_recv.items()): ls = self._load_specs.pop(sid, None) if ls is None or not ls.can_load: - logger.debug("[OFFLOAD-LOAD-SKIP] seq=%s ls=%s can_load=%s", - sid, ls is not None, getattr(ls, "can_load", None)) + logger.debug( + "[OFFLOAD-LOAD-SKIP] seq=%s ls=%s can_load=%s", + sid, + ls is not None, + getattr(ls, "can_load", None), + ) continue # ★ Use the REAL HBM-cached count as the load floor. # get_num_new_matched_tokens runs BEFORE the prefix-cache match in @@ -1252,7 +1315,9 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: # true HBM hit. Loading below this floor would overwrite HBM # prefix-cache blocks (possibly shared with other seqs) -> output # corruption. So load only [hbm_cached, offload_hit). - should_load, reason, hbm, lmc, need, chunk = self._decide_load_after_alloc(seq, ls) + should_load, reason, hbm, lmc, need, chunk = self._decide_load_after_alloc( + seq, ls + ) if not should_load: self._mark_load_skip(seq, reason, hbm, lmc, need, chunk) self._clear_pending_load(sid) @@ -1284,12 +1349,14 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: blocks=len(list(seq.block_table)), ) loading_sids.add(sid) - meta.add_request(LMCacheReqMeta( - req_id=seq.id, - token_ids=list(seq.token_ids[: lmc]), - block_ids=list(seq.block_table), - load_spec=ls, - )) + meta.add_request( + LMCacheReqMeta( + req_id=seq.id, + token_ids=list(seq.token_ids[:lmc]), + block_ids=list(seq.block_table), + load_spec=ls, + ) + ) meta.lookup_requests_in_step = self._lookup_in_step self._lookup_in_step = [] # Saves: store fully computed prompt chunks. Under scheduler-side @@ -1327,13 +1394,15 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: saved=saved, blocks=len(seq.block_table), ) - meta.add_request(LMCacheReqMeta( - req_id=seq.id, - token_ids=list(seq.token_ids[:aligned]), - block_ids=list(seq.block_table), - save_spec=SaveSpec(skip_leading_tokens=saved, can_save=True), - is_last_prefill=is_last_prefill, - )) + meta.add_request( + LMCacheReqMeta( + req_id=seq.id, + token_ids=list(seq.token_ids[:aligned]), + block_ids=list(seq.block_table), + save_spec=SaveSpec(skip_leading_tokens=saved, can_save=True), + is_last_prefill=is_last_prefill, + ) + ) entry[1] = aligned self._save_inflight.add(sid) self._reqs_need_recv.clear() diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py index 9f4009bbee..9f315e1c95 100644 --- a/atom/kv_transfer/offload/gpu_connector.py +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -91,11 +91,9 @@ def __init__(self, kv_caches: dict) -> None: self._tls = threading.local() self._native_stitch = None self._native_split = None - if ( - self.layout == "segment_indexed" - and os.environ.get("OFFLOAD_NATIVE_STITCH", "0").lower() - not in ("0", "false", "no", "off") - ): + if self.layout == "segment_indexed" and os.environ.get( + "OFFLOAD_NATIVE_STITCH", "0" + ).lower() not in ("0", "false", "no", "off"): try: from atom.kv_transfer.offload import native_stitch @@ -237,13 +235,9 @@ def stitch_chunk_buffers( src_bases_by_chunk = [ self._segment_bases(nblocks) for nblocks in chunk_block_counts ] - for seg_idx, (dst_base, nb) in enumerate( - zip(dst_bases, self._seg_block_bytes) - ): + for seg_idx, (dst_base, nb) in enumerate(zip(dst_bases, self._seg_block_bytes)): parts = [ - src[ - bases[seg_idx] : bases[seg_idx] + nblocks * nb - ] + src[bases[seg_idx] : bases[seg_idx] + nblocks * nb] for src, bases, nblocks in zip( chunk_buffers, src_bases_by_chunk, chunk_block_counts ) @@ -279,18 +273,14 @@ def split_request_buffer( dst_bases_by_chunk = [ self._segment_bases(nblocks) for nblocks in chunk_block_counts ] - for seg_idx, (src_base, nb) in enumerate( - zip(src_bases, self._seg_block_bytes) - ): + for seg_idx, (src_base, nb) in enumerate(zip(src_bases, self._seg_block_bytes)): logical_block_start = 0 for dst, bases, nblocks in zip( chunk_buffers, dst_bases_by_chunk, chunk_block_counts ): nbytes = nblocks * nb if nbytes: - dst[ - bases[seg_idx] : bases[seg_idx] + nbytes - ].copy_( + dst[bases[seg_idx] : bases[seg_idx] + nbytes].copy_( src[ src_base + logical_block_start * nb : src_base @@ -336,9 +326,7 @@ def gpu_to_host( stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() with stream_ctx: if self.layout == "segment_indexed": - idx = torch.tensor( - block_ids, dtype=torch.long, device=self._device - ) + idx = torch.tensor(block_ids, dtype=torch.long, device=self._device) bases = self._segment_bases(len(block_ids)) for seg, base, nb in zip( self._segments, bases, self._seg_block_bytes @@ -391,9 +379,7 @@ def host_to_gpu( stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() with stream_ctx: if self.layout == "segment_indexed": - idx = torch.tensor( - block_ids, dtype=torch.long, device=self._device - ) + idx = torch.tensor(block_ids, dtype=torch.long, device=self._device) bases = self._segment_bases(len(block_ids)) for seg, base, nb in zip( self._segments, bases, self._seg_block_bytes diff --git a/atom/kv_transfer/offload/native_stitch.py b/atom/kv_transfer/offload/native_stitch.py index 5a75c20522..65f40d0450 100644 --- a/atom/kv_transfer/offload/native_stitch.py +++ b/atom/kv_transfer/offload/native_stitch.py @@ -14,7 +14,6 @@ from torch.utils.cpp_extension import load - _EXT = None @@ -35,7 +34,9 @@ def load_extension() -> None: _load_ext() -def stitch_chunk_buffers(dst, chunk_buffers, chunk_block_counts, seg_block_bytes) -> None: +def stitch_chunk_buffers( + dst, chunk_buffers, chunk_block_counts, seg_block_bytes +) -> None: _load_ext().stitch_chunk_buffers( dst, chunk_buffers, @@ -44,7 +45,9 @@ def stitch_chunk_buffers(dst, chunk_buffers, chunk_block_counts, seg_block_bytes ) -def split_request_buffer(src, chunk_buffers, chunk_block_counts, seg_block_bytes) -> None: +def split_request_buffer( + src, chunk_buffers, chunk_block_counts, seg_block_bytes +) -> None: _load_ext().split_request_buffer( src, chunk_buffers, diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index db738745df..a397f28c99 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -60,6 +60,7 @@ def _scheduler() -> LMCacheOffloadConnectorScheduler: @pytest.mark.parametrize("layout", ["segment", "segment_indexed"]) def test_segment_major_codec_roundtrip_noncontiguous_blocks(monkeypatch, layout): import torch + if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") @@ -120,6 +121,7 @@ def test_segment_major_codec_roundtrip_noncontiguous_blocks(monkeypatch, layout) def test_segment_indexed_stitches_chunk_buffers(monkeypatch): import torch + if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") @@ -165,6 +167,7 @@ def test_segment_indexed_stitches_chunk_buffers(monkeypatch): @pytest.mark.parametrize("method_name", ["gpu_to_host", "host_to_gpu"]) def test_codec_rejects_invalid_block_ids_before_copy(monkeypatch, layout, method_name): import torch + if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") @@ -190,6 +193,7 @@ def test_codec_rejects_invalid_block_ids_before_copy(monkeypatch, layout, method def test_codec_rejects_short_host_buffer(monkeypatch): import torch + if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") @@ -211,6 +215,7 @@ def test_codec_rejects_short_host_buffer(monkeypatch): def test_copy_stream_is_cached_per_codec_device(monkeypatch): import torch + if not hasattr(torch, "device") or not hasattr(torch, "cuda"): pytest.skip("torch cuda API is unavailable") From bdbe0c1c0bd958268c7c18c41782f17e70e69ced Mon Sep 17 00:00:00 2001 From: yihonglie Date: Tue, 2 Jun 2026 03:26:07 -0500 Subject: [PATCH 10/27] Clarify parked offload prefill naming --- atom/model_engine/scheduler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index ccb5156800..d3c790442e 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -1474,7 +1474,7 @@ def _park_ready_offload_partial_prefills(self) -> None: ): return - ready: deque[Sequence] = deque() + parked: deque[Sequence] = deque() keep_running: deque[Sequence] = deque() while self.running: seq = self.running.popleft() @@ -1485,17 +1485,17 @@ def _park_ready_offload_partial_prefills(self) -> None: seq.is_partial_prefill = False self._partial_prefill_count -= 1 seq.status = SequenceStatus.WAITING_FOR_REMOTE_KVS - ready.append(seq) + parked.append(seq) else: keep_running.append(seq) self.running = keep_running - if ready: + if parked: offload_trace( "scheduler_park_partial_prefill_for_load", - reqs=[seq.id for seq in ready], + reqs=[seq.id for seq in parked], ) - self.waiting.extendleft(reversed(ready)) + self.waiting.extendleft(reversed(parked)) def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): """Reconcile scheduler state with completed KV transfers. From da02bddc1b7335e167863d92e67f584ae7f2255d Mon Sep 17 00:00:00 2001 From: yihonglie Date: Tue, 2 Jun 2026 12:01:19 -0500 Subject: [PATCH 11/27] Support max completion tokens in OpenAI API --- atom/entrypoints/openai/api_server.py | 83 +++++++++++++++++++- atom/entrypoints/openai/protocol.py | 18 +++++ tests/entrypoints/test_api_server_helpers.py | 34 ++++++++ tests/entrypoints/test_protocol.py | 34 ++++++++ 4 files changed, 165 insertions(+), 4 deletions(-) diff --git a/atom/entrypoints/openai/api_server.py b/atom/entrypoints/openai/api_server.py index 9ddbc2414e..6e253468c5 100644 --- a/atom/entrypoints/openai/api_server.py +++ b/atom/entrypoints/openai/api_server.py @@ -165,6 +165,43 @@ def _coerce_n(requested_n: Optional[int], temperature: Optional[float]) -> int: return n +def _validate_context_length( + num_prompt_tokens: int, + max_tokens: int, + max_model_len: Optional[int], +) -> None: + if max_model_len is None: + return + + requested_output_tokens = max(0, int(max_tokens or 0)) + total_tokens = int(num_prompt_tokens) + requested_output_tokens + if total_tokens <= int(max_model_len): + return + + raise ValueError( + f"This model's maximum context length is {max_model_len} tokens. " + f"However, you requested {requested_output_tokens} output tokens and " + f"your prompt contains at least {num_prompt_tokens} input tokens, for " + f"a total of at least {total_tokens} tokens. Please reduce the length " + f"of the input prompt or the number of requested output tokens." + ) + + +def _get_engine_max_model_len() -> Optional[int]: + config = getattr(engine, "config", None) + if config is None: + config = getattr(getattr(engine, "io_processor", None), "config", None) + return getattr(config, "max_model_len", None) + + +def _validate_sequence_context_length(seq) -> None: + _validate_context_length( + seq.num_prompt_tokens, + seq.max_tokens, + _get_engine_max_model_len(), + ) + + def _has_multimodal_content(messages: List[Any]) -> bool: for message in messages: content = getattr(message, "content", None) @@ -369,6 +406,11 @@ def do_preprocess(): ) seq = await loop.run_in_executor(None, do_preprocess) + try: + _validate_sequence_context_length(seq) + except Exception: + engine.io_processor.requests.pop(seq.id, None) + raise engine.core_mgr.add_request([seq]) while True: @@ -454,6 +496,11 @@ def do_preprocess(): ) seq = await loop.run_in_executor(None, do_preprocess) + try: + _validate_sequence_context_length(seq) + except Exception: + engine.io_processor.requests.pop(seq.id, None) + raise engine.core_mgr.add_request([seq]) while True: @@ -553,6 +600,12 @@ def do_preprocess(): ) seqs = await loop.run_in_executor(None, do_preprocess) + try: + _validate_sequence_context_length(seqs[0]) + except Exception: + for seq in seqs: + engine.io_processor.requests.pop(seq.id, None) + raise engine.core_mgr.add_request(seqs) num_tokens_input = seqs[0].num_prompt_tokens @@ -649,7 +702,18 @@ def do_preprocess(): _seq_id_to_request_id[seq.id] = request_id return seq - seq = await executor_loop.run_in_executor(None, do_preprocess) + seq = None + try: + seq = await executor_loop.run_in_executor(None, do_preprocess) + _validate_sequence_context_length(seq) + except Exception: + _stream_queues.pop(request_id, None) + _stream_loops.pop(request_id, None) + _request_start_times.pop(request_id, None) + if seq is not None: + _seq_id_to_request_id.pop(seq.id, None) + engine.io_processor.requests.pop(seq.id, None) + raise seq_id = seq.id logger.info(f"API: Created request_id={request_id}, seq_id={seq_id}") @@ -723,7 +787,18 @@ def do_preprocess(): _seq_id_to_request_id[seq.id] = request_id return seqs - seqs = await executor_loop.run_in_executor(None, do_preprocess) + seqs = [] + try: + seqs = await executor_loop.run_in_executor(None, do_preprocess) + _validate_sequence_context_length(seqs[0]) + except Exception: + _stream_queues.pop(request_id, None) + _stream_loops.pop(request_id, None) + _request_start_times.pop(request_id, None) + for seq in seqs: + _seq_id_to_request_id.pop(seq.id, None) + engine.io_processor.requests.pop(seq.id, None) + raise seq_ids = [seq.id for seq in seqs] logger.info( f"API: Created fan-out request_id={request_id}, n={n}, seq_ids={seq_ids}" @@ -802,7 +877,7 @@ async def chat_completions(request: ChatCompletionRequest): effective_n = _coerce_n(request.n, request.temperature) sampling_params = _build_sampling_params( temperature=request.temperature, - max_tokens=request.max_tokens, + max_tokens=request.get_max_tokens(), stop_strings=request.stop, ignore_eos=request.ignore_eos, top_k=request.top_k, @@ -931,7 +1006,7 @@ async def completions(request: CompletionRequest): effective_n = _coerce_n(request.n, request.temperature) sampling_params = _build_sampling_params( temperature=request.temperature, - max_tokens=request.max_tokens, + max_tokens=request.get_max_tokens(), stop_strings=request.stop, ignore_eos=request.ignore_eos, top_k=request.top_k, diff --git a/atom/entrypoints/openai/protocol.py b/atom/entrypoints/openai/protocol.py index 5e43382a5b..85f56ca02d 100644 --- a/atom/entrypoints/openai/protocol.py +++ b/atom/entrypoints/openai/protocol.py @@ -75,6 +75,7 @@ class ChatCompletionRequest(BaseModel): top_k: Optional[int] = DEFAULT_TOP_K top_p: Optional[float] = DEFAULT_TOP_P max_tokens: Optional[int] = DEFAULT_MAX_TOKENS + max_completion_tokens: Optional[int] = None stop: Optional[List[str]] = None ignore_eos: Optional[bool] = False stream: Optional[bool] = False @@ -90,6 +91,14 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = 0.0 n: Optional[int] = 1 + def get_max_tokens(self) -> int: + """Return the effective generation cap for OpenAI chat requests.""" + if self.max_completion_tokens is not None: + return self.max_completion_tokens + if self.max_tokens is not None: + return self.max_tokens + return DEFAULT_MAX_TOKENS + def get_messages(self) -> List[ChatMessage]: """Get messages from either 'messages' or 'prompt' field.""" if self.messages is not None: @@ -111,6 +120,7 @@ class CompletionRequest(BaseModel): top_k: Optional[int] = DEFAULT_TOP_K top_p: Optional[float] = DEFAULT_TOP_P max_tokens: Optional[int] = DEFAULT_MAX_TOKENS + max_completion_tokens: Optional[int] = None stop: Optional[List[str]] = None ignore_eos: Optional[bool] = False stream: Optional[bool] = False @@ -118,6 +128,14 @@ class CompletionRequest(BaseModel): kv_transfer_params: Optional[Dict[str, Any]] = None n: Optional[int] = 1 + def get_max_tokens(self) -> int: + """Return the effective generation cap for completion requests.""" + if self.max_completion_tokens is not None: + return self.max_completion_tokens + if self.max_tokens is not None: + return self.max_tokens + return DEFAULT_MAX_TOKENS + # ============================================================================ # Response Models diff --git a/tests/entrypoints/test_api_server_helpers.py b/tests/entrypoints/test_api_server_helpers.py index 7811f6276a..43368b0252 100644 --- a/tests/entrypoints/test_api_server_helpers.py +++ b/tests/entrypoints/test_api_server_helpers.py @@ -165,3 +165,37 @@ def test_invalid_n_rejected_by_sampling_params(self): ignore_eos=False, n=0, ) + + +class TestValidateContextLength: + """Oversized OpenAI requests should fail before entering the scheduler.""" + + def test_equal_to_max_model_len_is_allowed(self): + api_server._validate_context_length( + num_prompt_tokens=120, + max_tokens=8, + max_model_len=128, + ) + + def test_total_over_max_model_len_is_rejected(self): + with pytest.raises(ValueError, match="maximum context length is 128"): + api_server._validate_context_length( + num_prompt_tokens=121, + max_tokens=8, + max_model_len=128, + ) + + def test_prompt_alone_over_max_model_len_is_rejected(self): + with pytest.raises(ValueError, match="prompt contains at least 129"): + api_server._validate_context_length( + num_prompt_tokens=129, + max_tokens=0, + max_model_len=128, + ) + + def test_missing_max_model_len_skips_validation(self): + api_server._validate_context_length( + num_prompt_tokens=129, + max_tokens=8, + max_model_len=None, + ) diff --git a/tests/entrypoints/test_protocol.py b/tests/entrypoints/test_protocol.py index 41e899c71a..321c2c2f02 100644 --- a/tests/entrypoints/test_protocol.py +++ b/tests/entrypoints/test_protocol.py @@ -171,11 +171,33 @@ def test_defaults(self): ) assert req.temperature == 1.0 assert req.max_tokens == 8192 + assert req.get_max_tokens() == 8192 assert req.stream is False assert req.top_p == 1.0 assert req.top_k == -1 assert req.n == 1 + def test_max_completion_tokens_sets_effective_limit(self): + req = ChatCompletionRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hi"}], + "max_completion_tokens": 16, + } + ) + assert req.max_tokens == 8192 + assert req.max_completion_tokens == 16 + assert req.get_max_tokens() == 16 + + def test_max_tokens_still_sets_effective_limit(self): + req = ChatCompletionRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hi"}], + "max_tokens": 32, + } + ) + assert req.max_tokens == 32 + assert req.get_max_tokens() == 32 + def test_n_greater_than_one(self): req = ChatCompletionRequest.model_validate( { @@ -213,8 +235,20 @@ def test_basic_request(self): req = CompletionRequest(prompt="Hello world") assert req.prompt == "Hello world" assert req.max_tokens == 8192 + assert req.get_max_tokens() == 8192 assert req.n == 1 + def test_max_completion_tokens_sets_effective_limit(self): + req = CompletionRequest.model_validate( + { + "prompt": "Hello world", + "max_completion_tokens": 16, + } + ) + assert req.max_tokens == 8192 + assert req.max_completion_tokens == 16 + assert req.get_max_tokens() == 16 + def test_extra_fields_ignored(self): req = CompletionRequest.model_validate( {"prompt": "Hello", "unknown": "ignored"} From d23a9f4639753bd95953d4611da3a20d43386f0a Mon Sep 17 00:00:00 2001 From: yihonglie Date: Wed, 3 Jun 2026 02:37:30 -0500 Subject: [PATCH 12/27] Add LMCache-compatible offload connector --- atom/kv_transfer/offload/connector.py | 663 ++++----------------- atom/kv_transfer/offload/gpu_connector.py | 72 +++ atom/kv_transfer/offload/lmcache_compat.py | 373 ++++++++++++ tests/test_lmcache_offload_connector.py | 364 +++++++++++ 4 files changed, 937 insertions(+), 535 deletions(-) create mode 100644 atom/kv_transfer/offload/lmcache_compat.py diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index d2daa69bff..0753bb26f2 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -3,15 +3,16 @@ """ATOM standalone LMCache CPU/NVMe KV-offload connector. -Design (see ../../../../PLAN_impl_lmcache_offload_v5.md + 005 LEARN notes): - -* **Reuse real LMCache as a storage tier only** — per-rank ``LMCacheEngine`` for its - ``StorageManager`` (CPU LRU + NVMe L3) + ``ChunkedTokenDatabase`` (chunk-256 keys). - We bypass ``engine.store/retrieve`` (its token-major GPU path can't represent ATOM's - x-packed KV storage layout — ``K=(nb,H,D//x,bs,x)``, see ``ATOMKVByteCodec`` docstring; - loosely "swizzle", but a persistent storage layout, not LDS bank-swizzle) and instead - move **opaque per-block bytes** via :class:`ATOMKVByteCodec` into pinned - ``KV_2LTD``-as-uint8 ``MemoryObj``s. +Design: + +* **Use LMCache engine orchestration** — worker-side save/load calls + ``CacheEngine.store()`` / ``CacheEngine.retrieve()`` so LMCache owns chunking, + key generation, lookup pins, and storage-manager put/get. +* **ATOM-owned raw-byte GPU connector** — LMCache's stock vLLM GPU connectors + cannot represent ATOM's x-packed AITER KV layout + (``K=(nb,H,D//x,bs,x)``). We pass an ATOM ``GPUConnectorInterface`` + implementation that moves opaque per-block bytes with + :class:`ATOMKVByteCodec`. * **Daemon-after-forward copies** — ``start_load_kv`` only ``submit``s to a single serial copy daemon (ThreadPoolExecutor max_workers=1) and returns immediately, so the worker RPC thread is free for ``forward``; completions are polled in @@ -38,6 +39,10 @@ from atom.kv_transfer.disaggregation.types import KVConnectorOutput, ReqId from atom.kv_transfer.offload import config as offcfg from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec +from atom.kv_transfer.offload.lmcache_compat import ( + ATOMLMCacheGPUConnector, + ATOMRawBytesLMCacheMetadata, +) from atom.kv_transfer.offload.metadata import ( LMCacheOffloadMetadata, LMCacheReqMeta, @@ -49,30 +54,6 @@ logger = logging.getLogger("atom") -def _cdiv(a: int, b: int) -> int: - return -(-a // b) - - -class _UnusedGPUConnector: - """Satisfies LMCacheEngineBuilder.get_or_create; never invoked (we do our own - byte-copy and never call engine.store/retrieve).""" - - def to_gpu(self, *a, **k): - raise NotImplementedError - - def from_gpu(self, *a, **k): - raise NotImplementedError - - def batched_from_gpu(self, *a, **k): - raise NotImplementedError - - def batched_to_gpu(self, *a, **k): - raise NotImplementedError - - def get_shape(self, num_tokens): - return torch.Size((num_tokens,)) - - # ===================================================================== # Worker side # ===================================================================== @@ -97,8 +78,8 @@ def __init__(self, config) -> None: # parked seq is waiting for it) never queues behind a backlog of fire- # and-forget saves (Phase 4 root cause: with one shared serial daemon, a # reload sat behind ~N filler saves -> request hung well past timeout). - # Each worker thread gets its OWN CUDA stream (disjoint block_ids -> no - # write conflict). OFFLOAD_COPY_WORKERS tunes the SAVE pool only. + # The LMCache-compatible GPU connector owns per-thread staging streams. + # OFFLOAD_COPY_WORKERS tunes the SAVE pool only. n_save_workers = int(os.environ.get("OFFLOAD_COPY_WORKERS", "1")) self._load_executor = ThreadPoolExecutor( max_workers=1, thread_name_prefix="lmc-offload-load" @@ -111,10 +92,6 @@ def __init__(self, config) -> None: self._done_load: set[ReqId] = set() self._done_save: set[ReqId] = set() self._failed_load: set[ReqId] = set() - self._load_active = threading.Event() - self._request_fastpath = os.environ.get( - "OFFLOAD_REQUEST_FASTPATH", "1" - ).lower() not in ("0", "false", "no", "off") self._engine = None self._sm = None @@ -126,6 +103,7 @@ def __init__(self, config) -> None: def register_kv_caches(self, kv_caches: dict, transfer_tensors=None) -> None: from aiter.dist.parallel_state import get_tp_group from lmcache.v1.cache_engine import LMCacheEngineBuilder + from lmcache.v1.memory_management import MemoryFormat tp = get_tp_group() rank, world = tp.rank_in_group, tp.world_size @@ -135,21 +113,31 @@ def register_kv_caches(self, kv_caches: dict, transfer_tensors=None) -> None: offcfg.apply_extra_overrides( cfg, getattr(self._config, "kv_transfer_config", None) ) - meta = offcfg.build_lmcache_metadata(self._config, cfg, world, rank) self.chunk_size = int(cfg.chunk_size) + self._codec = ATOMKVByteCodec(kv_caches) + base_meta = offcfg.build_lmcache_metadata(self._config, cfg, world, rank) + meta = ATOMRawBytesLMCacheMetadata( + base_meta, + atom_block_size=self.block_size, + bytes_per_block=self._codec.bytes_per_block, + ) + gpu_connector = ATOMLMCacheGPUConnector(self._codec, self.block_size) self._engine = LMCacheEngineBuilder.get_or_create( f"atom-offload-{rank}", cfg, meta, - _UnusedGPUConnector(), + gpu_connector, lambda t, s: None, lambda o, s: o, ) + # LMCache's LocalCPU allocator does not accept BINARY for normal + # MemoryObj allocation. The metadata shape/dtype already make this an + # opaque uint8 object, so keep a supported tensor MemoryFormat. + self._engine.fmt = MemoryFormat.KV_2LTD self._engine.post_init() self._sm = self._engine.storage_manager self._tdb = self._engine.token_database - self._codec = ATOMKVByteCodec(kv_caches) # DEBUG: wrap engine.lookup to capture EVERY call (incl. the ones the ZMQ # lookup_server makes on behalf of the scheduler) — args + result. @@ -219,12 +207,6 @@ def start_load_kv(self, metadata) -> None: self._save_executor.submit(self._guard, "save", self._do_save_req, req) def _guard(self, kind: str, fn, req) -> None: - load_active = getattr(self, "_load_active", None) - if kind == "load" and load_active is None: - load_active = threading.Event() - self._load_active = load_active - if kind == "load": - load_active.set() try: fn(req) except Exception: @@ -240,9 +222,6 @@ def _guard(self, kind: str, fn, req) -> None: # A failed save should not keep blocks pinned forever. The # request simply loses this offload opportunity. self._done_save.add(req.req_id) - finally: - if kind == "load": - load_active.clear() def _lookup_unpin(self, req_id) -> None: if getattr(self, "_engine", None) is None: @@ -280,34 +259,6 @@ def _stream(self) -> torch.cuda.Stream: streams[key] = s return s - def _host_tmp(self, nbytes: int) -> torch.Tensor: - """Pinned CPU scratch buffer owned by the calling copy-daemon thread.""" - buf = getattr(self._tls, "host_tmp", None) - if buf is None or int(buf.numel()) < int(nbytes): - try: - buf = torch.empty((int(nbytes),), dtype=torch.uint8, pin_memory=True) - except RuntimeError: - logger.warning( - "LMCache offload: pinned host scratch allocation failed; " - "falling back to pageable CPU memory", - exc_info=True, - ) - buf = torch.empty((int(nbytes),), dtype=torch.uint8) - self._tls.host_tmp = buf - return buf[: int(nbytes)] - - def _pause_save_for_load(self, stream: torch.cuda.Stream) -> None: - """Let critical-path loads drain before fire-and-forget save copies.""" - load_active = getattr(self, "_load_active", None) - if load_active is None or not load_active.is_set(): - return - stream.synchronize() - while load_active.is_set(): - time.sleep(0.001) - - def _block_ids(self, req: LMCacheReqMeta, start: int, end: int) -> list[int]: - return req.block_ids[start // self.block_size : _cdiv(end, self.block_size)] - def _profile_enabled(self) -> bool: return os.environ.get("OFFLOAD_PROFILE", "1").lower() not in ( "0", @@ -316,56 +267,24 @@ def _profile_enabled(self) -> bool: "off", ) - def _request_fastpath_enabled(self) -> bool: - return ( - bool(getattr(self, "_request_fastpath", False)) - and self._codec is not None - and self._codec.layout == "segment_indexed" - ) - - def _request_level_key(self, chunks, token_count: int): - """Synthetic key for a whole-prefix segment-major object. - - The normal LMCache chunk keys remain authoritative for lookup. This key - is an optional per-rank fast path for exact full-prefix reloads: it uses - the last chunk's prefix hash plus tags, so it cannot collide with normal - chunk entries and stays stable across scheduler/worker processes. - """ - if not chunks: - return None - key = chunks[-1][2] - request_configs = dict(getattr(key, "request_configs", None) or {}) - request_configs["lmcache.tag.atom_offload"] = "request" - request_configs["lmcache.tag.atom_offload_tokens"] = str(int(token_count)) - request_configs["lmcache.tag.atom_offload_layout"] = str( - getattr(self._codec, "layout", "unknown") - ) - return key.__class__( - model_name=key.model_name, - world_size=key.world_size, - worker_id=key.worker_id, - chunk_hash=key.chunk_hash, - dtype=key.dtype, - request_configs=request_configs, - ) - # -- copy daemon thread ---------------------------------------------- def _do_load_req(self, req: LMCacheReqMeta) -> None: ls = req.load_spec assert ls is not None hbm = int(ls.hbm_cached_tokens) - toks = req.token_ids[: ls.lmcache_cached_tokens] + lmc = int(ls.lmcache_cached_tokens) + toks = req.token_ids[:lmc] t_total0 = time.perf_counter() offload_trace( "worker_load_start", rank=getattr(self, "_rank", "?"), req=req.req_id, hbm=hbm, - lmc=ls.lmcache_cached_tokens, + lmc=lmc, toks=len(toks), blocks=len(req.block_ids), ) - if int(ls.lmcache_cached_tokens) <= hbm: + if lmc <= hbm: self._lookup_unpin(req.req_id) with self._lock: self._done_load.add(req.req_id) @@ -399,281 +318,55 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", ) return - stream = self._stream() + mask = torch.ones(len(toks), dtype=torch.bool) mask[:hbm] = False - t0 = time.perf_counter() - chunks = list(self._tdb.process_tokens(torch.tensor(toks), mask=mask)) - process_ms = (time.perf_counter() - t0) * 1000 - logger.debug( - "offload _do_load req=%s hbm=%d lmc=%d chunks=%d", - req.req_id, - hbm, - ls.lmcache_cached_tokens, - len(chunks), - ) - - # All-or-nothing above the HBM prefix: a partial load would let attention - # read uninitialized blocks, and a chunk that overlaps an HBM-cache hit - # could overwrite shared prefix-cache blocks. In either case the seq - # wakes and re-prefills from its HBM floor. - if not chunks: - logger.warning( - "LMCache offload: no loadable chunks req=%s; re-prefill", req.req_id - ) - self._lookup_unpin(req.req_id) - with self._lock: - self._failed_load.add(req.req_id) - offload_trace( - "worker_load_done", - rank=getattr(self, "_rank", "?"), - req=req.req_id, - status="no_chunks", - total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", - ) - return - for s, _e, _key in chunks: - if s < hbm: - logger.warning( - "LMCache offload: chunk overlaps HBM prefix req=%s hbm=%d " - "chunk_start=%d; re-prefill", - req.req_id, - hbm, - s, - ) - self._lookup_unpin(req.req_id) - with self._lock: - self._failed_load.add(req.req_id) - offload_trace( - "worker_load_done", - rank=getattr(self, "_rank", "?"), - req=req.req_id, - status="overlap_hbm", - hbm=hbm, - chunk_start=s, - total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", - ) - return - contains_ms = 0.0 - loaded_objs = [] - get_ms = 0.0 - host_alloc_ms = 0.0 - stitch_ms = 0.0 - h2d_submit_ms = 0.0 - sync_ms = 0.0 - nblocks = 0 - nbytes = 0 - copy_calls = 0 - chunk_bids: list[list[int]] = [ - self._block_ids(req, s, e) for (s, e, _key) in chunks - ] - all_bids = [bid for bids in chunk_bids for bid in bids] - nblocks = len(all_bids) - nbytes = nblocks * self._codec.bytes_per_block - - request_key = None - if hbm == 0 and self._request_fastpath_enabled(): - request_key = self._request_level_key(chunks, len(toks)) - if request_key is not None: - req_mo = None - t0 = time.perf_counter() - request_location = self._sm.contains(request_key) - contains_ms += (time.perf_counter() - t0) * 1000 - if request_location: - try: - t0 = time.perf_counter() - req_mo = self._sm.get(request_key) - get_ms += (time.perf_counter() - t0) * 1000 - if req_mo is not None: - copy_calls = self._codec.copy_calls_for_block_ids(all_bids) - t0 = time.perf_counter() - self._codec.host_to_gpu(req_mo.tensor, all_bids, stream) - h2d_submit_ms += (time.perf_counter() - t0) * 1000 - t0 = time.perf_counter() - stream.synchronize() - sync_ms += (time.perf_counter() - t0) * 1000 - self._lookup_unpin(req.req_id) - with self._lock: - self._done_load.add(req.req_id) - total_ms = (time.perf_counter() - t_total0) * 1000 - offload_trace( - "worker_load_done", - rank=getattr(self, "_rank", "?"), - req=req.req_id, - status="ok_request", - chunks=len(chunks), - blocks=nblocks, - bytes_gib=f"{nbytes / 1024**3:.3f}", - stitch_ms=f"{stitch_ms:.2f}", - h2d_submit_ms=f"{h2d_submit_ms:.2f}", - sync_ms=f"{sync_ms:.2f}", - total_ms=f"{total_ms:.2f}", - ) - if self._profile_enabled(): - logger.info( - "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " - "chunks=%d blocks=%d bytes=%.3fGiB copy_calls=%d " - "layout=%s fastpath=request process_ms=%.2f " - "contains_ms=%.2f get_ms=%.2f host_alloc_ms=%.2f " - "stitch_ms=%.2f h2d_submit_ms=%.2f sync_ms=%.2f " - "total_ms=%.2f", - getattr(self, "_rank", "?"), - req.req_id, - hbm, - ls.lmcache_cached_tokens, - len(chunks), - nblocks, - nbytes / 1024**3, - copy_calls, - self._codec.layout, - process_ms, - contains_ms, - get_ms, - host_alloc_ms, - stitch_ms, - h2d_submit_ms, - sync_ms, - total_ms, - ) - logger.info("offload _do_load DONE req=%s", req.req_id) - return - finally: - if req_mo is not None: - req_mo.ref_count_down() - - for _s, _e, key in chunks: - t0 = time.perf_counter() - if not self._sm.contains(key): - contains_ms += (time.perf_counter() - t0) * 1000 - logger.warning( - "LMCache offload: load miss req=%s; re-prefill", req.req_id - ) - self._lookup_unpin(req.req_id) - with self._lock: - self._failed_load.add(req.req_id) - offload_trace( - "worker_load_done", - rank=getattr(self, "_rank", "?"), - req=req.req_id, - status="miss", - chunks=len(chunks), - total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", - ) - return - contains_ms += (time.perf_counter() - t0) * 1000 - try: - for s, e, key in chunks: - t0 = time.perf_counter() - mo = self._sm.get(key) - get_ms += (time.perf_counter() - t0) * 1000 - if mo is None: - t0 = time.perf_counter() - stream.synchronize() - sync_ms += (time.perf_counter() - t0) * 1000 - for loaded_mo in loaded_objs: - loaded_mo.ref_count_down() - self._lookup_unpin(req.req_id) - with self._lock: - self._failed_load.add(req.req_id) - offload_trace( - "worker_load_done", - rank=getattr(self, "_rank", "?"), - req=req.req_id, - status="get_none", - chunks=len(chunks), - total_ms=f"{(time.perf_counter() - t_total0) * 1000:.2f}", - ) - return - loaded_objs.append(mo) - bids = chunk_bids[len(loaded_objs) - 1] - if self._codec.layout != "segment_indexed": - copy_calls += self._codec.copy_calls_for_block_ids(bids) - t0 = time.perf_counter() - self._codec.host_to_gpu(mo.tensor, bids, stream) - h2d_submit_ms += (time.perf_counter() - t0) * 1000 - if self._codec.layout == "segment_indexed": - copy_calls = self._codec.copy_calls_for_block_ids(all_bids) - t0 = time.perf_counter() - req_buf = self._host_tmp(nbytes) - host_alloc_ms += (time.perf_counter() - t0) * 1000 - t0 = time.perf_counter() - self._codec.stitch_chunk_buffers( - req_buf, - [mo.tensor for mo in loaded_objs], - [len(bids) for bids in chunk_bids], - ) - stitch_ms += (time.perf_counter() - t0) * 1000 - t0 = time.perf_counter() - self._codec.host_to_gpu(req_buf, all_bids, stream) - h2d_submit_ms += (time.perf_counter() - t0) * 1000 - t0 = time.perf_counter() - stream.synchronize() - sync_ms += (time.perf_counter() - t0) * 1000 - except Exception: - try: - t0 = time.perf_counter() - stream.synchronize() - sync_ms += (time.perf_counter() - t0) * 1000 - finally: - for loaded_mo in loaded_objs: - loaded_mo.ref_count_down() - self._lookup_unpin(req.req_id) - raise - for mo in loaded_objs: - mo.ref_count_down() - # Release the lookup pin (taken by the scheduler's LookupClient.lookup) - # now that the chunks are safely in GPU; lets the pool evict them later. + t_retrieve0 = time.perf_counter() + ret_mask = self._engine.retrieve( + torch.tensor(toks), + mask=mask, + block_ids=req.block_ids, + req_id=str(req.req_id), + ) + retrieve_ms = (time.perf_counter() - t_retrieve0) * 1000 self._lookup_unpin(req.req_id) + loaded = bool(ret_mask[hbm:lmc].all().item()) if lmc > hbm else True with self._lock: - self._done_load.add(req.req_id) + if loaded: + self._done_load.add(req.req_id) + else: + self._failed_load.add(req.req_id) total_ms = (time.perf_counter() - t_total0) * 1000 offload_trace( "worker_load_done", rank=getattr(self, "_rank", "?"), req=req.req_id, - status="ok", - chunks=len(chunks), - blocks=nblocks, - bytes_gib=f"{nbytes / 1024**3:.3f}", - stitch_ms=f"{stitch_ms:.2f}", - h2d_submit_ms=f"{h2d_submit_ms:.2f}", - sync_ms=f"{sync_ms:.2f}", + status="ok" if loaded else "miss", + hbm=hbm, + lmc=lmc, + retrieved=int(ret_mask.sum().item()), + retrieve_ms=f"{retrieve_ms:.2f}", total_ms=f"{total_ms:.2f}", ) if self._profile_enabled(): logger.info( "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " - "chunks=%d blocks=%d bytes=%.3fGiB copy_calls=%d " - "layout=%s fastpath=chunk process_ms=%.2f contains_ms=%.2f " - "get_ms=%.2f host_alloc_ms=%.2f stitch_ms=%.2f " - "h2d_submit_ms=%.2f sync_ms=%.2f total_ms=%.2f", + "retrieved=%d status=%s retrieve_ms=%.2f total_ms=%.2f", getattr(self, "_rank", "?"), req.req_id, hbm, - ls.lmcache_cached_tokens, - len(chunks), - nblocks, - nbytes / 1024**3, - copy_calls, - self._codec.layout, - process_ms, - contains_ms, - get_ms, - host_alloc_ms, - stitch_ms, - h2d_submit_ms, - sync_ms, + lmc, + int(ret_mask.sum().item()), + "ok" if loaded else "miss", + retrieve_ms, total_ms, ) logger.info("offload _do_load DONE req=%s", req.req_id) def _do_save_req(self, req: LMCacheReqMeta) -> None: - from lmcache.v1.memory_management import MemoryFormat - ss = req.save_spec assert ss is not None - stream = self._stream() toks = req.token_ids if not req.is_last_prefill: toks = toks[: (len(toks) // self.chunk_size) * self.chunk_size] @@ -701,127 +394,15 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: ) mask = torch.ones(len(toks), dtype=torch.bool) mask[:skip] = False - t0 = time.perf_counter() - chunks = list(self._tdb.process_tokens(torch.tensor(toks), mask=mask)) - process_ms = (time.perf_counter() - t0) * 1000 - - keys, objs, already = [], [], 0 - request_key = None - request_obj = None - request_fastpath = "off" - if skip == 0 and self._request_fastpath_enabled() and chunks: - request_key = self._request_level_key(chunks, len(toks)) - request_fastpath = "miss" - t0 = time.perf_counter() - if self._sm.contains(request_key): - request_key = None - request_fastpath = "hit" - contains_ms = (time.perf_counter() - t0) * 1000 - else: - contains_ms = 0.0 - put_started = False - alloc_ms = 0.0 - host_alloc_ms = 0.0 - d2h_submit_ms = 0.0 - sync_ms = 0.0 - split_ms = 0.0 - put_ms = 0.0 - nblocks = 0 - total_nbytes = 0 - copy_calls = 0 - chunk_bids: list[list[int]] = [] - try: - for s, e, key in chunks: - self._pause_save_for_load(stream) - t0 = time.perf_counter() - if self._sm.contains(key): # already offloaded → skip wasted D2H - contains_ms += (time.perf_counter() - t0) * 1000 - already += 1 - continue - contains_ms += (time.perf_counter() - t0) * 1000 - bids = self._block_ids(req, s, e) - chunk_nbytes = len(bids) * self._codec.bytes_per_block - t0 = time.perf_counter() - mo = self._sm.allocate( - torch.Size((chunk_nbytes,)), torch.uint8, fmt=MemoryFormat.KV_2LTD - ) - alloc_ms += (time.perf_counter() - t0) * 1000 - if mo is None: # pool under pressure; stop here - break - keys.append(key) - objs.append(mo) - chunk_bids.append(bids) - nblocks += len(bids) - total_nbytes += chunk_nbytes - if self._codec.layout != "segment_indexed": - copy_calls += self._codec.copy_calls_for_block_ids(bids) - # D2H on this thread's dedicated copy stream (off compute stream). - t0 = time.perf_counter() - self._codec.gpu_to_host(mo.tensor, bids, stream) - d2h_submit_ms += (time.perf_counter() - t0) * 1000 - - if keys: - if self._codec.layout == "segment_indexed": - all_bids = [bid for bids in chunk_bids for bid in bids] - copy_calls = self._codec.copy_calls_for_block_ids(all_bids) - if request_key is not None and len(keys) == len(chunks): - t0 = time.perf_counter() - request_obj = self._sm.allocate( - torch.Size((total_nbytes,)), - torch.uint8, - fmt=MemoryFormat.KV_2LTD, - ) - alloc_ms += (time.perf_counter() - t0) * 1000 - if request_obj is not None: - req_buf = request_obj.tensor - request_fastpath = "stored" - else: - request_key = None - request_fastpath = "alloc_failed" - else: - request_key = None - if request_fastpath == "miss": - request_fastpath = "partial_skip" - if request_obj is None: - t0 = time.perf_counter() - req_buf = self._host_tmp(total_nbytes) - host_alloc_ms += (time.perf_counter() - t0) * 1000 - t0 = time.perf_counter() - self._codec.gpu_to_host(req_buf, all_bids, stream) - d2h_submit_ms += (time.perf_counter() - t0) * 1000 - t0 = time.perf_counter() - stream.synchronize() # stream-specific - sync_ms += (time.perf_counter() - t0) * 1000 - if self._codec.layout == "segment_indexed": - t0 = time.perf_counter() - self._codec.split_request_buffer( - req_buf, - [mo.tensor for mo in objs], - [len(bids) for bids in chunk_bids], - ) - split_ms += (time.perf_counter() - t0) * 1000 - put_started = True - t0 = time.perf_counter() - put_keys = list(keys) - put_objs = list(objs) - if request_key is not None and request_obj is not None: - put_keys.append(request_key) - put_objs.append(request_obj) - self._sm.batched_put(put_keys, put_objs) - put_ms += (time.perf_counter() - t0) * 1000 - except Exception: - if not put_started: - try: - t0 = time.perf_counter() - stream.synchronize() - sync_ms += (time.perf_counter() - t0) * 1000 - finally: - cleanup_objs = list(objs) - if request_obj is not None: - cleanup_objs.append(request_obj) - for mo in cleanup_objs: - mo.ref_count_down() - raise + + t_store0 = time.perf_counter() + self._engine.store( + torch.tensor(toks), + mask=mask, + block_ids=req.block_ids, + req_id=str(req.req_id), + ) + store_ms = (time.perf_counter() - t_store0) * 1000 with self._lock: self._done_save.add(req.req_id) total_ms = (time.perf_counter() - t_total0) * 1000 @@ -831,60 +412,21 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: req=req.req_id, status="ok", toks=len(toks), - chunks=len(chunks), - stored=len(keys), - blocks=nblocks, - bytes_gib=f"{total_nbytes / 1024**3:.3f}", - d2h_submit_ms=f"{d2h_submit_ms:.2f}", - sync_ms=f"{sync_ms:.2f}", - split_ms=f"{split_ms:.2f}", - request_fastpath=request_fastpath, + skip=skip, + store_ms=f"{store_ms:.2f}", total_ms=f"{total_ms:.2f}", ) if self._profile_enabled(): logger.info( - "[OFFLOAD-SAVE-PROF] rank=%s req=%s toks=%d chunks=%d " - "stored=%d already=%d blocks=%d bytes=%.3fGiB copy_calls=%d " - "layout=%s request_fastpath=%s process_ms=%.2f " - "contains_ms=%.2f alloc_ms=%.2f host_alloc_ms=%.2f " - "d2h_submit_ms=%.2f sync_ms=%.2f split_ms=%.2f " - "put_ms=%.2f total_ms=%.2f", + "[OFFLOAD-SAVE-PROF] rank=%s req=%s toks=%d skip=%d " + "store_ms=%.2f total_ms=%.2f", getattr(self, "_rank", "?"), req.req_id, len(toks), - len(chunks), - len(keys), - already, - nblocks, - total_nbytes / 1024**3, - copy_calls, - self._codec.layout, - request_fastpath, - process_ms, - contains_ms, - alloc_ms, - host_alloc_ms, - d2h_submit_ms, - sync_ms, - split_ms, - put_ms, + skip, + store_ms, total_ms, ) - if logger.isEnabledFor(logging.DEBUG): - _kh = [getattr(k, "chunk_hash", None) for k in keys[:2]] - _contains = [bool(self._sm.contains(k)) for k in keys[:2]] - logger.debug( - "[OFFLOAD-SAVE] rank=%s req=%s toks=%d chunks=%d stored=%d already=%d " - "chunkhash2=%s contains=%s", - self._rank, - req.req_id, - len(toks), - len(chunks), - len(keys), - already, - _kh, - _contains, - ) # -- per-step (RPC thread, post-forward): poll completions ------------ def get_finished(self) -> KVConnectorOutput: @@ -943,6 +485,14 @@ def __init__(self, config) -> None: self._load_specs: dict[str, LoadSpec] = {} # req_id -> Sequence (queued to recv this step) self._reqs_need_recv: dict[str, object] = {} + # req_id -> HBM chunk frontier for an emitted load. If the load fails, + # lower the save frontier to this value so recomputed chunks can be + # stored again. + self._load_save_floors: dict[str, int] = {} + # req_id -> LMCache chunk frontier observed by lookup. The scheduler + # should not re-save this already-persisted prefix unless a later load + # actually fails. + self._hit_save_floors: dict[str, int] = {} # Persistent save tracker: sid -> [seq, saved_offset]. A seq's prompt # prefix is stored to LMCache once prefill computes it # (seq.prefix_hashes_published flips True), chunk by chunk. @@ -1038,6 +588,7 @@ def get_num_new_matched_tokens(self, seq) -> tuple[int, bool]: hit = int(hit) if hit == num_prompt: # full-prompt hit → recompute last token hit -= 1 + self._hit_save_floors[sid] = self._chunk_floor(hit) need = hit - int(seq.num_cached_tokens) if need <= 0: if self._lookup_client is not None: @@ -1092,13 +643,47 @@ def update_state_after_alloc(self, seq) -> None: ) # Track for save; build_connector_meta stores chunks once the scheduler's # computed frontier (seq.num_cached_tokens) has advanced past them. + # + # If LMCache lookup already found a prefix for this request, do not save + # that prefix again. This covers both direct loads and the + # hbm_satisfies_after_alloc case where HBM prefix cache already covers + # the lookup hit. Only suffix chunks computed by this request should be + # stored. + initial_saved = max( + self._lmcache_hit_save_floor(ls), + int(self._hit_save_floors.get(sid, 0)), + ) if sid not in self._save_tracker: - self._save_tracker[sid] = [seq, 0] + self._save_tracker[sid] = [seq, initial_saved] + else: + self._save_tracker[sid][0] = seq + self._save_tracker[sid][1] = max( + int(self._save_tracker[sid][1]), initial_saved + ) + + def _chunk_floor(self, tokens: int) -> int: + chunk = int(self.chunk_size or 256) + return (max(0, int(tokens)) // chunk) * chunk + + def _lmcache_hit_save_floor(self, ls: LoadSpec | None) -> int: + if ls is None: + return 0 + return self._chunk_floor(ls.lmcache_cached_tokens) + + def _set_save_frontier(self, sid: str, seq, saved: int) -> None: + saved = self._chunk_floor(saved) + if sid not in self._save_tracker: + self._save_tracker[sid] = [seq, saved] + else: + self._save_tracker[sid][0] = seq + self._save_tracker[sid][1] = saved def _clear_pending_load(self, sid: str) -> None: self._load_specs.pop(sid, None) self._reqs_need_recv.pop(sid, None) self._handoff_loads.discard(sid) + self._load_save_floors.pop(sid, None) + self._hit_save_floors.pop(sid, None) self._lookup_in_step = [ req_id for req_id in self._lookup_in_step if req_id != sid ] @@ -1349,6 +934,7 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: blocks=len(list(seq.block_table)), ) loading_sids.add(sid) + self._load_save_floors[sid] = self._chunk_floor(hbm) meta.add_request( LMCacheReqMeta( req_id=seq.id, @@ -1409,12 +995,11 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: return meta def _save_frontier(self, seq) -> int: - chunk = self.chunk_size or 256 computed = min( int(getattr(seq, "num_cached_tokens", 0)), int(getattr(seq, "num_prompt_tokens", 0)), ) - return (computed // chunk) * chunk + return self._chunk_floor(computed) def _has_pending_save(self, seq) -> bool: sid = str(seq.id) @@ -1431,7 +1016,15 @@ def save_finished(self, req_id) -> None: self._save_inflight.discard(str(req_id)) def load_failed(self, req_id) -> None: - self._clear_pending_load(str(req_id)) + sid = str(req_id) + floor = self._load_save_floors.get(sid) + entry = self._save_tracker.get(sid) + if floor is not None and entry is not None: + # The LMCache hit was not actually loaded. Let the recomputed + # [HBM, LMC) chunks be saved again instead of permanently treating + # them as already persisted. + entry[1] = self._chunk_floor(floor) + self._clear_pending_load(sid) def request_finished(self, seq) -> None: sid = str(seq.id) diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py index 9f315e1c95..b9aef991dd 100644 --- a/atom/kv_transfer/offload/gpu_connector.py +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -209,6 +209,22 @@ def _validate_host_buf(self, host_buf: torch.Tensor, nblocks: int) -> None: f"got {int(host_buf.numel())}" ) + def _validate_device_buf(self, device_buf: torch.Tensor, nblocks: int) -> None: + if device_buf.dtype != torch.uint8: + raise TypeError("ATOMKVByteCodec: device_buf must be a uint8 tensor") + if device_buf.device != self._device: + raise TypeError( + "ATOMKVByteCodec: device_buf must be on the KV cache device " + f"{self._device}, got {device_buf.device}" + ) + required = int(nblocks) * self.bytes_per_block + if int(device_buf.numel()) < required: + raise ValueError( + "ATOMKVByteCodec: device_buf is too small " + f"for {nblocks} blocks; need {required} bytes, " + f"got {int(device_buf.numel())}" + ) + def stitch_chunk_buffers( self, dst: torch.Tensor, @@ -419,6 +435,62 @@ def host_to_gpu( non_blocking=True, ) + def gpu_to_device_buffer( + self, + device_buf: torch.Tensor, + block_ids: list[int], + stream: torch.cuda.Stream | None = None, + ) -> None: + """Gather ATOM KV blocks into a flat device staging buffer. + + The staging layout is always segment-major: + ``[seg0 blocks | seg1 blocks | ...]``. This is the layout consumed by + the LMCache-compatible connector before it copies the bytes to a + ``MemoryObj``. + """ + block_ids = self._normalize_block_ids(block_ids) + self._validate_device_buf(device_buf, len(block_ids)) + if not block_ids: + return + with self._device_ctx(): + stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() + with stream_ctx: + idx = torch.tensor(block_ids, dtype=torch.long, device=self._device) + bases = self._segment_bases(len(block_ids)) + for seg, base, nb in zip( + self._segments, bases, self._seg_block_bytes + ): + mat = self._segment_bytes_matrix(seg) + dst = device_buf[ + base : base + len(block_ids) * nb + ].reshape(len(block_ids), nb) + torch.index_select(mat, 0, idx, out=dst) + + def device_buffer_to_gpu( + self, + device_buf: torch.Tensor, + block_ids: list[int], + stream: torch.cuda.Stream | None = None, + ) -> None: + """Scatter a segment-major device staging buffer into ATOM KV blocks.""" + block_ids = self._normalize_block_ids(block_ids) + self._validate_device_buf(device_buf, len(block_ids)) + if not block_ids: + return + with self._device_ctx(): + stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() + with stream_ctx: + idx = torch.tensor(block_ids, dtype=torch.long, device=self._device) + bases = self._segment_bases(len(block_ids)) + for seg, base, nb in zip( + self._segments, bases, self._seg_block_bytes + ): + mat = self._segment_bytes_matrix(seg) + src = device_buf[ + base : base + len(block_ids) * nb + ].reshape(len(block_ids), nb) + mat.index_copy_(0, idx, src) + class _nullctx: def __enter__(self): diff --git a/atom/kv_transfer/offload/lmcache_compat.py b/atom/kv_transfer/offload/lmcache_compat.py new file mode 100644 index 0000000000..cb8b0d9265 --- /dev/null +++ b/atom/kv_transfer/offload/lmcache_compat.py @@ -0,0 +1,373 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""LMCache-compatible raw-byte connector for ATOM offload. + +This module lets ATOM use LMCache ``CacheEngine.store()`` / +``CacheEngine.retrieve()`` without adopting LMCache's vLLM token-major KV +layout. LMCache still owns chunking, keys, lookup pins, and storage-manager +orchestration. ATOM owns how a token range maps to AITER KV-cache blocks and +how those blocks are packed as opaque bytes. +""" + +from __future__ import annotations + +import threading +from typing import Any + +import torch + +from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec + + +def _cdiv(a: int, b: int) -> int: + return -(-int(a) // int(b)) + + +class ATOMRawBytesLMCacheMetadata: + """Proxy around ``LMCacheMetadata`` with ATOM raw-byte allocation shapes.""" + + def __init__( + self, + base_metadata: Any, + *, + atom_block_size: int, + bytes_per_block: int, + ) -> None: + self._atom_base_metadata = base_metadata + self.__dict__.update(vars(base_metadata)) + self.atom_block_size = int(atom_block_size) + self.atom_bytes_per_block = int(bytes_per_block) + chunk_size = int(getattr(base_metadata, "chunk_size")) + if self.atom_block_size <= 0: + raise ValueError("ATOM raw-byte metadata: atom_block_size must be > 0") + if self.atom_bytes_per_block <= 0: + raise ValueError("ATOM raw-byte metadata: bytes_per_block must be > 0") + if chunk_size % self.atom_block_size != 0: + raise ValueError( + "LMCache chunk size must be divisible by ATOM KV block size: " + f"chunk_size={chunk_size}, block_size={self.atom_block_size}" + ) + + def __getattr__(self, name: str) -> Any: + return getattr(self._atom_base_metadata, name) + + def __eq__(self, other: object) -> bool: + if isinstance(other, ATOMRawBytesLMCacheMetadata): + return ( + self._atom_base_metadata == other._atom_base_metadata + and self.atom_block_size == other.atom_block_size + and self.atom_bytes_per_block == other.atom_bytes_per_block + ) + return False + + def is_first_rank(self) -> bool: + return self._atom_base_metadata.is_first_rank() + + def get_dtypes(self) -> list[torch.dtype]: + return [torch.uint8] + + def get_shapes(self, num_tokens: int | None = None) -> list[torch.Size]: + if num_tokens is None: + num_tokens = int(self.chunk_size) + nblocks = _cdiv(int(num_tokens), self.atom_block_size) + return [torch.Size((nblocks * self.atom_bytes_per_block,))] + + def get_num_groups(self) -> int: + return 1 + + +class _NullCtx: + def __enter__(self): + return None + + def __exit__(self, *args): + return False + + +class _StagingSlot: + def __init__(self, use_cuda: bool) -> None: + self.tensor: torch.Tensor | None = None + self.ready_event = None + self.free_event = None + self.free_event_valid = False + if use_cuda: + self.ready_event = torch.cuda.Event(blocking=False) + self.free_event = torch.cuda.Event(blocking=False) + + +class _ThreadTransferState: + def __init__(self, device: torch.device, use_cuda: bool) -> None: + self.device = device + self.use_cuda = use_cuda + self.pack_stream = None + self.copy_stream = None + self.next_slot = 0 + self.host_tmp: torch.Tensor | None = None + if use_cuda: + with torch.cuda.device(device): + self.pack_stream = torch.cuda.Stream() + self.copy_stream = torch.cuda.Stream() + self.slots = [_StagingSlot(use_cuda), _StagingSlot(use_cuda)] + else: + self.slots = [_StagingSlot(use_cuda), _StagingSlot(use_cuda)] + + def stream_ctx(self, stream): + if stream is None: + return _NullCtx() + return torch.cuda.stream(stream) + + +class ATOMLMCacheGPUConnector: + """LMCache GPUConnectorInterface for ATOM's opaque KV-block byte layout.""" + + def __init__(self, codec: ATOMKVByteCodec, block_size: int) -> None: + self.codec = codec + self.block_size = int(block_size) + if self.block_size <= 0: + raise ValueError("ATOM LMCache connector: block_size must be > 0") + self.device = torch.device(codec.device) + self._tls = threading.local() + + def _use_cuda(self) -> bool: + return self.device.type == "cuda" + + def _thread_state(self) -> _ThreadTransferState: + states = getattr(self._tls, "states", None) + if states is None: + states = {} + self._tls.states = states + key = str(self.device) + state = states.get(key) + if state is None: + state = _ThreadTransferState(self.device, self._use_cuda()) + states[key] = state + return state + + def _ensure_slot(self, slot: _StagingSlot, nbytes: int) -> torch.Tensor: + nbytes = int(nbytes) + if slot.tensor is None or int(slot.tensor.numel()) < nbytes: + slot.tensor = torch.empty( + (nbytes,), + dtype=torch.uint8, + device=self.device, + ) + slot.free_event_valid = False + return slot.tensor[:nbytes] + + def _next_slot(self, state: _ThreadTransferState) -> _StagingSlot: + slot = state.slots[state.next_slot % len(state.slots)] + state.next_slot += 1 + return slot + + def _ensure_host_tmp( + self, + state: _ThreadTransferState, + nbytes: int, + ) -> torch.Tensor: + nbytes = int(nbytes) + if state.host_tmp is None or int(state.host_tmp.numel()) < nbytes: + if state.use_cuda: + try: + state.host_tmp = torch.empty( + (nbytes,), + dtype=torch.uint8, + pin_memory=True, + ) + except RuntimeError: + state.host_tmp = torch.empty((nbytes,), dtype=torch.uint8) + else: + state.host_tmp = torch.empty((nbytes,), dtype=torch.uint8) + return state.host_tmp[:nbytes] + + def _memory_tensor(self, memory_obj: Any, nbytes: int) -> torch.Tensor: + tensor = getattr(memory_obj, "tensor", None) + if tensor is None and hasattr(memory_obj, "get_tensor"): + tensor = memory_obj.get_tensor(0) + if tensor is None: + raise RuntimeError("ATOM LMCache connector: invalid MemoryObj tensor") + if tensor.dtype != torch.uint8: + raise TypeError( + "ATOM LMCache connector: MemoryObj tensor must be uint8, " + f"got {tensor.dtype}" + ) + if not tensor.is_contiguous(): + raise RuntimeError("ATOM LMCache connector: MemoryObj tensor not contiguous") + flat = tensor.reshape(-1) + if int(flat.numel()) < int(nbytes): + raise ValueError( + "ATOM LMCache connector: MemoryObj tensor is too small " + f"for {nbytes} bytes; got {int(flat.numel())}" + ) + return flat[: int(nbytes)] + + def _range_block_ids( + self, + all_block_ids: list[int], + start: int, + end: int, + ) -> list[int]: + start = int(start) + end = int(end) + if start < 0 or end < start: + raise ValueError( + f"invalid LMCache token range for ATOM KV blocks: {start}:{end}" + ) + if start % self.block_size != 0: + raise ValueError( + "LMCache chunk start must be ATOM block-aligned: " + f"start={start}, block_size={self.block_size}" + ) + start_block = start // self.block_size + end_block = _cdiv(end, self.block_size) + if end_block > len(all_block_ids): + raise ValueError( + "LMCache token range exceeds ATOM block table: " + f"range={start}:{end}, needed_blocks={end_block}, " + f"available_blocks={len(all_block_ids)}" + ) + return list(all_block_ids[start_block:end_block]) + + def _ranges_to_block_ids( + self, + starts: list[int], + ends: list[int], + **kwargs, + ) -> list[list[int]]: + block_ids = kwargs.get("block_ids") + if block_ids is None: + raise ValueError("ATOM LMCache connector requires block_ids") + all_block_ids = [int(bid) for bid in block_ids] + return [ + self._range_block_ids(all_block_ids, start, end) + for start, end in zip(starts, ends, strict=True) + ] + + def from_gpu(self, memory_obj: Any, start: int, end: int, **kwargs) -> None: + self.batched_from_gpu([memory_obj], [start], [end], **kwargs) + + def to_gpu(self, memory_obj: Any, start: int, end: int, **kwargs) -> None: + self.batched_to_gpu([memory_obj], [start], [end], **kwargs) + + def batched_from_gpu( + self, + memory_objs: list[Any], + starts: list[int], + ends: list[int], + **kwargs, + ) -> None: + """Pack ATOM KV blocks to LMCache MemoryObjs via double GPU staging.""" + if not (len(memory_objs) == len(starts) == len(ends)): + raise ValueError("memory_objs, starts, and ends must have equal length") + block_id_groups = self._ranges_to_block_ids(starts, ends, **kwargs) + if not memory_objs: + return + + state = self._thread_state() + use_cuda = state.use_cuda + chunk_block_counts = [len(block_ids) for block_ids in block_id_groups] + all_block_ids = [ + block_id for block_ids in block_id_groups for block_id in block_ids + ] + total_nbytes = len(all_block_ids) * self.codec.bytes_per_block + if total_nbytes == 0: + return + + slot = self._next_slot(state) + device_buf = self._ensure_slot(slot, total_nbytes) + host_buf = self._ensure_host_tmp(state, total_nbytes) + dst_tensors = [ + self._memory_tensor( + memory_obj, + block_count * self.codec.bytes_per_block, + ) + for memory_obj, block_count in zip( + memory_objs, + chunk_block_counts, + strict=True, + ) + ] + + if use_cuda: + if slot.free_event_valid: + state.pack_stream.wait_event(slot.free_event) + with state.stream_ctx(state.pack_stream): + self.codec.gpu_to_device_buffer( + device_buf, + all_block_ids, + stream=state.pack_stream, + ) + slot.ready_event.record(state.pack_stream) + state.copy_stream.wait_event(slot.ready_event) + with state.stream_ctx(state.copy_stream): + host_buf.copy_(device_buf, non_blocking=True) + slot.free_event.record(state.copy_stream) + slot.free_event_valid = True + state.copy_stream.synchronize() + else: + self.codec.gpu_to_device_buffer(device_buf, all_block_ids) + host_buf.copy_(device_buf, non_blocking=False) + + self.codec.split_request_buffer(host_buf, dst_tensors, chunk_block_counts) + + def batched_to_gpu( + self, + memory_objs: list[Any] | None = None, + starts: list[int] | None = None, + ends: list[int] | None = None, + **kwargs, + ) -> None: + """Load LMCache MemoryObjs back into ATOM KV blocks via double staging.""" + if memory_objs is None or starts is None or ends is None: + raise ValueError("memory_objs, starts, and ends are required") + if not (len(memory_objs) == len(starts) == len(ends)): + raise ValueError("memory_objs, starts, and ends must have equal length") + block_id_groups = self._ranges_to_block_ids(starts, ends, **kwargs) + if not memory_objs: + return + + state = self._thread_state() + use_cuda = state.use_cuda + chunk_block_counts = [len(block_ids) for block_ids in block_id_groups] + all_block_ids = [ + block_id for block_ids in block_id_groups for block_id in block_ids + ] + total_nbytes = len(all_block_ids) * self.codec.bytes_per_block + if total_nbytes == 0: + return + + slot = self._next_slot(state) + device_buf = self._ensure_slot(slot, total_nbytes) + host_buf = self._ensure_host_tmp(state, total_nbytes) + src_tensors = [ + self._memory_tensor( + memory_obj, + block_count * self.codec.bytes_per_block, + ) + for memory_obj, block_count in zip( + memory_objs, + chunk_block_counts, + strict=True, + ) + ] + self.codec.stitch_chunk_buffers(host_buf, src_tensors, chunk_block_counts) + + if use_cuda: + if slot.free_event_valid: + state.copy_stream.wait_event(slot.free_event) + with state.stream_ctx(state.copy_stream): + device_buf.copy_(host_buf, non_blocking=True) + slot.ready_event.record(state.copy_stream) + state.pack_stream.wait_event(slot.ready_event) + with state.stream_ctx(state.pack_stream): + self.codec.device_buffer_to_gpu( + device_buf, + all_block_ids, + stream=state.pack_stream, + ) + slot.free_event.record(state.pack_stream) + slot.free_event_valid = True + state.pack_stream.synchronize() + else: + device_buf.copy_(host_buf, non_blocking=False) + self.codec.device_buffer_to_gpu(device_buf, all_block_ids) diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index a397f28c99..c278fb897a 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -22,6 +22,10 @@ ) from atom.kv_transfer.offload import connector as offload_connector_mod from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec +from atom.kv_transfer.offload.lmcache_compat import ( + ATOMLMCacheGPUConnector, + ATOMRawBytesLMCacheMetadata, +) from atom.model_engine.scheduler import Scheduler @@ -46,6 +50,8 @@ def _scheduler() -> LMCacheOffloadConnectorScheduler: sched._lookup_client = _LookupClient(hit=0) sched._load_specs = {} sched._reqs_need_recv = {} + sched._load_save_floors = {} + sched._hit_save_floors = {} sched._save_tracker = {} sched._save_inflight = set() sched._lookup_in_step = [] @@ -163,6 +169,144 @@ def test_segment_indexed_stitches_chunk_buffers(monkeypatch): assert torch.equal(actual, expected) +def test_raw_bytes_metadata_shapes_are_block_rounded(): + import torch + + if not hasattr(torch, "Size"): + pytest.skip("real torch is unavailable") + + base = SimpleNamespace(chunk_size=8) + base.is_first_rank = lambda: True + meta = ATOMRawBytesLMCacheMetadata( + base, + atom_block_size=4, + bytes_per_block=32, + ) + + assert meta.get_dtypes() == [torch.uint8] + assert meta.get_shapes(8) == [torch.Size((64,))] + assert meta.get_shapes(6) == [torch.Size((64,))] + assert meta.get_shapes(4) == [torch.Size((32,))] + assert meta.get_shapes() == [torch.Size((64,))] + + +def test_raw_bytes_metadata_rejects_unaligned_chunk_size(): + import torch + + if not hasattr(torch, "Size"): + pytest.skip("real torch is unavailable") + + base = SimpleNamespace(chunk_size=10) + with pytest.raises(ValueError, match="chunk size must be divisible"): + ATOMRawBytesLMCacheMetadata( + base, + atom_block_size=4, + bytes_per_block=32, + ) + + +def test_codec_device_buffer_roundtrip_noncontiguous_blocks(monkeypatch): + import torch + + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "block") + original = { + "l0": SimpleNamespace( + k_cache=torch.arange(8 * 2 * 3, dtype=torch.uint8).reshape(8, 2, 3), + v_cache=(torch.arange(8 * 4, dtype=torch.uint8).reshape(8, 4) + 51), + k_scale=None, + v_scale=None, + ), + "l1": SimpleNamespace( + k_cache=(torch.arange(8 * 3, dtype=torch.uint8).reshape(8, 3) + 101), + v_cache=(torch.arange(8 * 2, dtype=torch.uint8).reshape(8, 2) + 151), + k_scale=None, + v_scale=None, + ), + } + kv_caches = { + name: SimpleNamespace(k_cache=layer.k_cache.clone(), v_cache=layer.v_cache.clone()) + for name, layer in original.items() + } + for layer in kv_caches.values(): + layer.k_scale = None + layer.v_scale = None + + codec = ATOMKVByteCodec(kv_caches) + block_ids = [1, 3, 4, 7] + device_buf = torch.empty( + len(block_ids) * codec.bytes_per_block, + dtype=torch.uint8, + device=codec.device, + ) + + codec.gpu_to_device_buffer(device_buf, block_ids) + for layer in kv_caches.values(): + layer.k_cache.zero_() + layer.v_cache.zero_() + codec.device_buffer_to_gpu(device_buf, block_ids) + + for name, layer in kv_caches.items(): + src = original[name] + for bid in block_ids: + assert torch.equal(layer.k_cache[bid], src.k_cache[bid]) + assert torch.equal(layer.v_cache[bid], src.v_cache[bid]) + + +def test_lmcache_connector_maps_token_ranges_to_block_ids(monkeypatch): + import torch + + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "block") + original = { + "l0": SimpleNamespace( + k_cache=torch.arange(6 * 2, dtype=torch.uint8).reshape(6, 2), + v_cache=(torch.arange(6 * 3, dtype=torch.uint8).reshape(6, 3) + 51), + k_scale=None, + v_scale=None, + ) + } + kv_caches = { + "l0": SimpleNamespace( + k_cache=original["l0"].k_cache.clone(), + v_cache=original["l0"].v_cache.clone(), + k_scale=None, + v_scale=None, + ) + } + codec = ATOMKVByteCodec(kv_caches) + connector = ATOMLMCacheGPUConnector(codec, block_size=4) + memory_obj = SimpleNamespace( + tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) + ) + + connector.batched_from_gpu( + [memory_obj], + [4], + [12], + block_ids=[0, 1, 2, 3, 4, 5], + ) + + kv_caches["l0"].k_cache.zero_() + kv_caches["l0"].v_cache.zero_() + connector.batched_to_gpu( + [memory_obj], + [4], + [12], + block_ids=[0, 1, 2, 3, 4, 5], + ) + + for bid in [1, 2]: + assert torch.equal(kv_caches["l0"].k_cache[bid], original["l0"].k_cache[bid]) + assert torch.equal(kv_caches["l0"].v_cache[bid], original["l0"].v_cache[bid]) + assert torch.count_nonzero(kv_caches["l0"].k_cache[0]) == 0 + assert torch.count_nonzero(kv_caches["l0"].v_cache[0]) == 0 + + @pytest.mark.parametrize("layout", ["block", "segment", "segment_indexed"]) @pytest.mark.parametrize("method_name", ["gpu_to_host", "host_to_gpu"]) def test_codec_rejects_invalid_block_ids_before_copy(monkeypatch, layout, method_name): @@ -300,13 +444,47 @@ def test_load_is_skipped_if_hbm_satisfies_after_allocation(): assert sched.should_park_for_load_after_alloc(seq) is False meta = sched.build_connector_meta() + assert meta.requests == [] assert [req for req in meta.requests if req.load_spec is not None] == [] assert seq.offload_loaded_tokens == 8 + assert sched._save_tracker[str(seq.id)][1] == 8 assert lookup.cleared == ["321"] assert str(seq.id) not in sched._load_specs assert str(seq.id) not in sched._reqs_need_recv +def test_lookup_time_hbm_satisfies_does_not_resave_hit_prefix(): + sched = _scheduler() + lookup = _LookupClient(hit=8) + sched._lookup_client = lookup + seq = SimpleNamespace( + id=322, + num_prompt_tokens=12, + token_ids=list(range(12)), + num_cached_tokens=8, + block_table=[1, 2, 3], + ) + + need, should_park = sched.get_num_new_matched_tokens(seq) + assert need == 0 + assert should_park is False + + sched.update_state_after_alloc(seq) + meta1 = sched.build_connector_meta() + + assert meta1.requests == [] + assert sched._save_tracker[str(seq.id)][1] == 8 + assert lookup.cleared == ["322"] + + seq.num_cached_tokens = 12 + meta2 = sched.build_connector_meta() + save_reqs = [req for req in meta2.requests if req.save_spec is not None] + + assert len(save_reqs) == 1 + assert save_reqs[0].token_ids == list(range(12)) + assert save_reqs[0].save_spec.skip_leading_tokens == 8 + + def test_load_is_skipped_if_hbm_floor_is_not_chunk_aligned(): sched = _scheduler() lookup = _LookupClient(hit=12) @@ -469,6 +647,71 @@ def test_aligned_large_hit_parks_and_emits_load_metadata(): assert lookup.cleared == [] +def test_loaded_prefix_is_not_saved_again_after_success(): + sched = _scheduler() + sched._min_load_tokens = 8 + sched._lookup_client = _LookupClient(hit=12) + seq = SimpleNamespace( + id=659, + num_prompt_tokens=16, + token_ids=list(range(16)), + num_cached_tokens=0, + block_table=[1, 2, 3, 4], + ) + + need, should_park = sched.get_num_new_matched_tokens(seq) + assert need == 12 + assert should_park is True + + seq.num_cached_tokens = 4 + sched.update_state_after_alloc(seq) + assert sched.should_park_for_load_after_alloc(seq) is True + + load_meta = sched.build_connector_meta() + assert len([req for req in load_meta.requests if req.load_spec is not None]) == 1 + assert [req for req in load_meta.requests if req.save_spec is not None] == [] + assert sched._save_tracker[str(seq.id)][1] == 12 + + seq.num_cached_tokens = 16 + save_meta = sched.build_connector_meta() + save_reqs = [req for req in save_meta.requests if req.save_spec is not None] + + assert len(save_reqs) == 1 + assert save_reqs[0].token_ids == list(range(16)) + assert save_reqs[0].save_spec.skip_leading_tokens == 12 + + +def test_load_failure_allows_recomputed_hit_range_to_be_saved(): + sched = _scheduler() + sched._min_load_tokens = 8 + sched._lookup_client = _LookupClient(hit=12) + seq = SimpleNamespace( + id=660, + num_prompt_tokens=16, + token_ids=list(range(16)), + num_cached_tokens=0, + block_table=[1, 2, 3, 4], + ) + + sched.get_num_new_matched_tokens(seq) + seq.num_cached_tokens = 4 + sched.update_state_after_alloc(seq) + assert sched.should_park_for_load_after_alloc(seq) is True + sched.build_connector_meta() + assert sched._save_tracker[str(seq.id)][1] == 12 + + sched.load_failed(seq.id) + assert sched._save_tracker[str(seq.id)][1] == 4 + + seq.num_cached_tokens = 12 + save_meta = sched.build_connector_meta() + save_reqs = [req for req in save_meta.requests if req.save_spec is not None] + + assert len(save_reqs) == 1 + assert save_reqs[0].token_ids == list(range(12)) + assert save_reqs[0].save_spec.skip_leading_tokens == 4 + + def test_worker_completes_noop_load_when_hbm_satisfies(): conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) conn._lock = threading.Lock() @@ -516,6 +759,127 @@ def test_worker_reports_unaligned_hbm_load_as_failed_without_exception(): assert conn._engine.unpinned == ["654"] +def test_worker_save_uses_lmcache_engine_store(): + import torch + + if not hasattr(torch, "tensor"): + pytest.skip("real torch is unavailable") + + class _Engine: + def __init__(self) -> None: + self.calls = [] + + def store(self, tokens, mask=None, **kwargs) -> None: + self.calls.append((tokens.tolist(), mask.tolist(), kwargs)) + + conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) + conn._lock = threading.Lock() + conn._done_save = set() + conn.chunk_size = 4 + conn._engine = _Engine() + + req = SimpleNamespace( + req_id=987, + token_ids=list(range(12)), + block_ids=[3, 4, 5], + is_last_prefill=True, + save_spec=SimpleNamespace(skip_leading_tokens=4), + ) + + conn._do_save_req(req) + + assert conn._done_save == {987} + assert len(conn._engine.calls) == 1 + tokens, mask, kwargs = conn._engine.calls[0] + assert tokens == list(range(12)) + assert mask == [False, False, False, False] + [True] * 8 + assert kwargs["block_ids"] == [3, 4, 5] + assert kwargs["req_id"] == "987" + + +def test_worker_load_uses_lmcache_engine_retrieve_and_marks_done(): + import torch + + if not hasattr(torch, "tensor"): + pytest.skip("real torch is unavailable") + + class _Engine: + def __init__(self) -> None: + self.calls = [] + self.unpinned = [] + + def retrieve(self, tokens, mask=None, **kwargs): + self.calls.append((tokens.tolist(), mask.tolist(), kwargs)) + return torch.tensor([False] * 4 + [True] * 8, dtype=torch.bool) + + def lookup_unpin(self, ids) -> None: + self.unpinned.extend(ids) + + conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) + conn._lock = threading.Lock() + conn._done_load = set() + conn._failed_load = set() + conn._done_save = set() + conn.chunk_size = 4 + conn._engine = _Engine() + + req = SimpleNamespace( + req_id=988, + token_ids=list(range(16)), + block_ids=[3, 4, 5, 6], + load_spec=SimpleNamespace(hbm_cached_tokens=4, lmcache_cached_tokens=12), + ) + + conn._do_load_req(req) + + assert conn._done_load == {988} + assert conn._failed_load == set() + assert conn._engine.unpinned == ["988"] + tokens, mask, kwargs = conn._engine.calls[0] + assert tokens == list(range(12)) + assert mask == [False, False, False, False] + [True] * 8 + assert kwargs["block_ids"] == [3, 4, 5, 6] + assert kwargs["req_id"] == "988" + + +def test_worker_load_partial_retrieve_marks_failed(): + import torch + + if not hasattr(torch, "tensor"): + pytest.skip("real torch is unavailable") + + class _Engine: + def __init__(self) -> None: + self.unpinned = [] + + def retrieve(self, tokens, mask=None, **kwargs): + return torch.tensor([False] * 4 + [True] * 4 + [False] * 4) + + def lookup_unpin(self, ids) -> None: + self.unpinned.extend(ids) + + conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) + conn._lock = threading.Lock() + conn._done_load = set() + conn._failed_load = set() + conn._done_save = set() + conn.chunk_size = 4 + conn._engine = _Engine() + + req = SimpleNamespace( + req_id=989, + token_ids=list(range(16)), + block_ids=[3, 4, 5, 6], + load_spec=SimpleNamespace(hbm_cached_tokens=4, lmcache_cached_tokens=12), + ) + + conn._do_load_req(req) + + assert conn._done_load == set() + assert conn._failed_load == {989} + assert conn._engine.unpinned == ["989"] + + def test_load_exception_is_reported_as_failed_recving(): conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) conn._lock = threading.Lock() From b0e300efdd1b41152053d4fc9a7c679ff89d1d24 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Wed, 3 Jun 2026 04:29:42 -0500 Subject: [PATCH 13/27] Add fused chunk-major LMCache staging --- atom/kv_transfer/offload/connector.py | 31 ++- atom/kv_transfer/offload/gpu_connector.py | 135 ++++++++- atom/kv_transfer/offload/lmcache_compat.py | 84 +++++- .../kv_transfer/offload/native_kv_staging.cpp | 222 +++++++++++++++ atom/kv_transfer/offload/native_kv_staging.py | 65 +++++ .../offload/native_kv_staging_kernel.hip | 145 ++++++++++ tests/test_lmcache_offload_connector.py | 263 +++++++++++++++++- 7 files changed, 926 insertions(+), 19 deletions(-) create mode 100644 atom/kv_transfer/offload/native_kv_staging.cpp create mode 100644 atom/kv_transfer/offload/native_kv_staging.py create mode 100644 atom/kv_transfer/offload/native_kv_staging_kernel.hip diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index 0753bb26f2..f796640160 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -267,6 +267,24 @@ def _profile_enabled(self) -> bool: "off", ) + def _last_gpu_connector_fastpath(self) -> str: + gpu_connector = getattr(getattr(self, "_engine", None), "gpu_connector", None) + if gpu_connector is None or not hasattr(gpu_connector, "last_fastpath"): + return "unknown" + try: + return str(gpu_connector.last_fastpath()) + except Exception: + return "unknown" + + def _reset_gpu_connector_fastpath(self) -> None: + gpu_connector = getattr(getattr(self, "_engine", None), "gpu_connector", None) + if gpu_connector is None or not hasattr(gpu_connector, "reset_fastpath"): + return + try: + gpu_connector.reset_fastpath() + except Exception: + pass + # -- copy daemon thread ---------------------------------------------- def _do_load_req(self, req: LMCacheReqMeta) -> None: ls = req.load_spec @@ -323,6 +341,7 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: mask[:hbm] = False t_retrieve0 = time.perf_counter() + self._reset_gpu_connector_fastpath() ret_mask = self._engine.retrieve( torch.tensor(toks), mask=mask, @@ -330,6 +349,7 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: req_id=str(req.req_id), ) retrieve_ms = (time.perf_counter() - t_retrieve0) * 1000 + fastpath = self._last_gpu_connector_fastpath() self._lookup_unpin(req.req_id) loaded = bool(ret_mask[hbm:lmc].all().item()) if lmc > hbm else True with self._lock: @@ -346,19 +366,22 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: hbm=hbm, lmc=lmc, retrieved=int(ret_mask.sum().item()), + fastpath=fastpath, retrieve_ms=f"{retrieve_ms:.2f}", total_ms=f"{total_ms:.2f}", ) if self._profile_enabled(): logger.info( "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " - "retrieved=%d status=%s retrieve_ms=%.2f total_ms=%.2f", + "retrieved=%d status=%s fastpath=%s retrieve_ms=%.2f " + "total_ms=%.2f", getattr(self, "_rank", "?"), req.req_id, hbm, lmc, int(ret_mask.sum().item()), "ok" if loaded else "miss", + fastpath, retrieve_ms, total_ms, ) @@ -396,6 +419,7 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: mask[:skip] = False t_store0 = time.perf_counter() + self._reset_gpu_connector_fastpath() self._engine.store( torch.tensor(toks), mask=mask, @@ -403,6 +427,7 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: req_id=str(req.req_id), ) store_ms = (time.perf_counter() - t_store0) * 1000 + fastpath = self._last_gpu_connector_fastpath() with self._lock: self._done_save.add(req.req_id) total_ms = (time.perf_counter() - t_total0) * 1000 @@ -413,17 +438,19 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: status="ok", toks=len(toks), skip=skip, + fastpath=fastpath, store_ms=f"{store_ms:.2f}", total_ms=f"{total_ms:.2f}", ) if self._profile_enabled(): logger.info( "[OFFLOAD-SAVE-PROF] rank=%s req=%s toks=%d skip=%d " - "store_ms=%.2f total_ms=%.2f", + "fastpath=%s store_ms=%.2f total_ms=%.2f", getattr(self, "_rank", "?"), req.req_id, len(toks), skip, + fastpath, store_ms, total_ms, ) diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py index b9aef991dd..8ccad58ee8 100644 --- a/atom/kv_transfer/offload/gpu_connector.py +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -91,6 +91,7 @@ def __init__(self, kv_caches: dict) -> None: self._tls = threading.local() self._native_stitch = None self._native_split = None + self._native_kv_staging = None if self.layout == "segment_indexed" and os.environ.get( "OFFLOAD_NATIVE_STITCH", "0" ).lower() not in ("0", "false", "no", "off"): @@ -105,6 +106,20 @@ def __init__(self, kv_caches: dict) -> None: "ATOMKVByteCodec: native stitch unavailable; using torch stitch", exc_info=True, ) + if self._device.type == "cuda" and os.environ.get( + "OFFLOAD_NATIVE_KV_STAGING", "0" + ).lower() not in ("0", "false", "no", "off"): + try: + from atom.kv_transfer.offload import native_kv_staging + + native_kv_staging.load_extension() + self._native_kv_staging = native_kv_staging + except Exception: + logger.warning( + "ATOMKVByteCodec: native KV staging unavailable; " + "using chunk fallback", + exc_info=True, + ) @property def segments_per_block(self) -> int: @@ -114,6 +129,10 @@ def segments_per_block(self) -> int: def device(self) -> torch.device: return self._device + @property + def has_native_chunk_major_staging(self) -> bool: + return self._native_kv_staging is not None + def copy_calls_for_blocks(self, nblocks: int) -> int: return int(nblocks) * len(self._segments) @@ -198,6 +217,20 @@ def _normalize_block_ids(self, block_ids: list[int]) -> list[int]: ) return normalized + def _normalize_block_id_groups( + self, + block_id_groups: list[list[int]], + *, + reject_repeated: bool, + ) -> tuple[list[list[int]], list[int], list[int]]: + groups = [ + self._normalize_block_ids(list(block_ids)) for block_ids in block_id_groups + ] + flat = [bid for block_ids in groups for bid in block_ids] + if reject_repeated and len(set(flat)) != len(flat): + raise ValueError("ATOMKVByteCodec: duplicate block ids are not supported") + return groups, flat, [len(block_ids) for block_ids in groups] + def _validate_host_buf(self, host_buf: torch.Tensor, nblocks: int) -> None: if host_buf.dtype != torch.uint8: raise TypeError("ATOMKVByteCodec: host_buf must be a uint8 tensor") @@ -457,13 +490,11 @@ def gpu_to_device_buffer( with stream_ctx: idx = torch.tensor(block_ids, dtype=torch.long, device=self._device) bases = self._segment_bases(len(block_ids)) - for seg, base, nb in zip( - self._segments, bases, self._seg_block_bytes - ): + for seg, base, nb in zip(self._segments, bases, self._seg_block_bytes): mat = self._segment_bytes_matrix(seg) - dst = device_buf[ - base : base + len(block_ids) * nb - ].reshape(len(block_ids), nb) + dst = device_buf[base : base + len(block_ids) * nb].reshape( + len(block_ids), nb + ) torch.index_select(mat, 0, idx, out=dst) def device_buffer_to_gpu( @@ -482,15 +513,95 @@ def device_buffer_to_gpu( with stream_ctx: idx = torch.tensor(block_ids, dtype=torch.long, device=self._device) bases = self._segment_bases(len(block_ids)) - for seg, base, nb in zip( - self._segments, bases, self._seg_block_bytes - ): + for seg, base, nb in zip(self._segments, bases, self._seg_block_bytes): mat = self._segment_bytes_matrix(seg) - src = device_buf[ - base : base + len(block_ids) * nb - ].reshape(len(block_ids), nb) + src = device_buf[base : base + len(block_ids) * nb].reshape( + len(block_ids), nb + ) mat.index_copy_(0, idx, src) + def gpu_to_chunk_major_device_buffer( + self, + device_buf: torch.Tensor, + block_id_groups: list[list[int]], + stream: torch.cuda.Stream | None = None, + ) -> None: + """Gather ATOM KV blocks into a chunk-major device staging buffer. + + Layout is MemoryObj-compatible: + ``[chunk0: seg0 blocks | seg1 blocks | ...][chunk1: ...]``. + Native fused staging is used when available; otherwise this method + provides a reference implementation for tests and CPU fallback. + """ + groups, flat_block_ids, chunk_block_counts = self._normalize_block_id_groups( + block_id_groups, + reject_repeated=True, + ) + self._validate_device_buf(device_buf, len(flat_block_ids)) + if not flat_block_ids: + return + with self._device_ctx(): + stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() + with stream_ctx: + if self._native_kv_staging is not None: + self._native_kv_staging.fused_pack_chunk_major( + self._segments, + self._seg_block_bytes, + chunk_block_counts, + flat_block_ids, + device_buf, + ) + return + + offset = 0 + for block_ids in groups: + nblocks = len(block_ids) + chunk_nbytes = nblocks * self.bytes_per_block + self.gpu_to_device_buffer( + device_buf[offset : offset + chunk_nbytes], + block_ids, + stream=stream, + ) + offset += chunk_nbytes + + def chunk_major_device_buffer_to_gpu( + self, + device_buf: torch.Tensor, + block_id_groups: list[list[int]], + stream: torch.cuda.Stream | None = None, + ) -> None: + """Scatter a chunk-major device staging buffer into ATOM KV blocks.""" + groups, flat_block_ids, chunk_block_counts = self._normalize_block_id_groups( + block_id_groups, + reject_repeated=True, + ) + self._validate_device_buf(device_buf, len(flat_block_ids)) + if not flat_block_ids: + return + with self._device_ctx(): + stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() + with stream_ctx: + if self._native_kv_staging is not None: + self._native_kv_staging.fused_unpack_chunk_major( + device_buf, + self._segments, + self._seg_block_bytes, + chunk_block_counts, + flat_block_ids, + ) + return + + offset = 0 + for block_ids in groups: + nblocks = len(block_ids) + chunk_nbytes = nblocks * self.bytes_per_block + self.device_buffer_to_gpu( + device_buf[offset : offset + chunk_nbytes], + block_ids, + stream=stream, + ) + offset += chunk_nbytes + class _nullctx: def __enter__(self): diff --git a/atom/kv_transfer/offload/lmcache_compat.py b/atom/kv_transfer/offload/lmcache_compat.py index cb8b0d9265..0be2008524 100644 --- a/atom/kv_transfer/offload/lmcache_compat.py +++ b/atom/kv_transfer/offload/lmcache_compat.py @@ -129,6 +129,15 @@ def __init__(self, codec: ATOMKVByteCodec, block_size: int) -> None: self.device = torch.device(codec.device) self._tls = threading.local() + def _set_last_fastpath(self, value: str) -> None: + self._tls.last_fastpath = value + + def reset_fastpath(self) -> None: + self._set_last_fastpath("none") + + def last_fastpath(self) -> str: + return getattr(self._tls, "last_fastpath", "unknown") + def _use_cuda(self) -> bool: return self.device.type == "cuda" @@ -180,6 +189,9 @@ def _ensure_host_tmp( state.host_tmp = torch.empty((nbytes,), dtype=torch.uint8) return state.host_tmp[:nbytes] + def _can_use_fused_chunk_major(self) -> bool: + return self._use_cuda() and self.codec.has_native_chunk_major_staging + def _memory_tensor(self, memory_obj: Any, nbytes: int) -> torch.Tensor: tensor = getattr(memory_obj, "tensor", None) if tensor is None and hasattr(memory_obj, "get_tensor"): @@ -192,7 +204,9 @@ def _memory_tensor(self, memory_obj: Any, nbytes: int) -> torch.Tensor: f"got {tensor.dtype}" ) if not tensor.is_contiguous(): - raise RuntimeError("ATOM LMCache connector: MemoryObj tensor not contiguous") + raise RuntimeError( + "ATOM LMCache connector: MemoryObj tensor not contiguous" + ) flat = tensor.reshape(-1) if int(flat.numel()) < int(nbytes): raise ValueError( @@ -275,7 +289,6 @@ def batched_from_gpu( slot = self._next_slot(state) device_buf = self._ensure_slot(slot, total_nbytes) - host_buf = self._ensure_host_tmp(state, total_nbytes) dst_tensors = [ self._memory_tensor( memory_obj, @@ -288,6 +301,38 @@ def batched_from_gpu( ) ] + if self._can_use_fused_chunk_major(): + self._set_last_fastpath("fused_chunk") + if slot.free_event_valid: + state.pack_stream.wait_event(slot.free_event) + with state.stream_ctx(state.pack_stream): + self.codec.gpu_to_chunk_major_device_buffer( + device_buf, + block_id_groups, + stream=state.pack_stream, + ) + slot.ready_event.record(state.pack_stream) + state.copy_stream.wait_event(slot.ready_event) + with state.stream_ctx(state.copy_stream): + offset = 0 + for dst, block_count in zip( + dst_tensors, + chunk_block_counts, + strict=True, + ): + nbytes = block_count * self.codec.bytes_per_block + dst.copy_( + device_buf[offset : offset + nbytes], + non_blocking=True, + ) + offset += nbytes + slot.free_event.record(state.copy_stream) + slot.free_event_valid = True + state.copy_stream.synchronize() + return + + self._set_last_fastpath("chunk") + host_buf = self._ensure_host_tmp(state, total_nbytes) if use_cuda: if slot.free_event_valid: state.pack_stream.wait_event(slot.free_event) @@ -338,7 +383,6 @@ def batched_to_gpu( slot = self._next_slot(state) device_buf = self._ensure_slot(slot, total_nbytes) - host_buf = self._ensure_host_tmp(state, total_nbytes) src_tensors = [ self._memory_tensor( memory_obj, @@ -350,8 +394,40 @@ def batched_to_gpu( strict=True, ) ] - self.codec.stitch_chunk_buffers(host_buf, src_tensors, chunk_block_counts) + if self._can_use_fused_chunk_major(): + self._set_last_fastpath("fused_chunk") + if slot.free_event_valid: + state.copy_stream.wait_event(slot.free_event) + with state.stream_ctx(state.copy_stream): + offset = 0 + for src, block_count in zip( + src_tensors, + chunk_block_counts, + strict=True, + ): + nbytes = block_count * self.codec.bytes_per_block + device_buf[offset : offset + nbytes].copy_( + src, + non_blocking=True, + ) + offset += nbytes + slot.ready_event.record(state.copy_stream) + state.pack_stream.wait_event(slot.ready_event) + with state.stream_ctx(state.pack_stream): + self.codec.chunk_major_device_buffer_to_gpu( + device_buf, + block_id_groups, + stream=state.pack_stream, + ) + slot.free_event.record(state.pack_stream) + slot.free_event_valid = True + state.pack_stream.synchronize() + return + + self._set_last_fastpath("chunk") + host_buf = self._ensure_host_tmp(state, total_nbytes) + self.codec.stitch_chunk_buffers(host_buf, src_tensors, chunk_block_counts) if use_cuda: if slot.free_event_valid: state.copy_stream.wait_event(slot.free_event) diff --git a/atom/kv_transfer/offload/native_kv_staging.cpp b/atom/kv_transfer/offload/native_kv_staging.cpp new file mode 100644 index 0000000000..9dd558354d --- /dev/null +++ b/atom/kv_transfer/offload/native_kv_staging.cpp @@ -0,0 +1,222 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include +#include + +#define ATOM_HIP_CHECK(cmd) \ + do { \ + hipError_t err = (cmd); \ + TORCH_CHECK( \ + err == hipSuccess, "HIP error: ", hipGetErrorString(err)); \ + } while (0) + +hipError_t launch_fused_pack_chunk_major( + uint8_t* device_buf, + const int64_t* segment_ptrs, + const int64_t* segment_block_bytes, + const int64_t* segment_prefix_bytes, + const int64_t* chunk_block_counts, + const int64_t* chunk_block_offsets, + const int64_t* chunk_output_bases, + const int64_t* block_ids, + int64_t num_chunks, + int64_t num_segments, + hipStream_t stream); + +hipError_t launch_fused_unpack_chunk_major( + const uint8_t* device_buf, + const int64_t* segment_ptrs, + const int64_t* segment_block_bytes, + const int64_t* segment_prefix_bytes, + const int64_t* chunk_block_counts, + const int64_t* chunk_block_offsets, + const int64_t* chunk_output_bases, + const int64_t* block_ids, + int64_t num_chunks, + int64_t num_segments, + hipStream_t stream); + +namespace { + +torch::Tensor make_device_i64( + const std::vector& values, + const torch::Device& device, + hipStream_t stream) { + auto tensor = torch::empty( + {static_cast(values.size())}, + torch::TensorOptions().dtype(torch::kInt64).device(device)); + if (!values.empty()) { + ATOM_HIP_CHECK(hipMemcpyAsync( + tensor.data_ptr(), + values.data(), + values.size() * sizeof(int64_t), + hipMemcpyHostToDevice, + stream)); + } + return tensor; +} + +struct StagingMeta { + torch::Tensor segment_ptrs; + torch::Tensor segment_block_bytes; + torch::Tensor segment_prefix_bytes; + torch::Tensor chunk_block_counts; + torch::Tensor chunk_block_offsets; + torch::Tensor chunk_output_bases; + torch::Tensor block_ids; + int64_t num_chunks = 0; + int64_t num_segments = 0; + int64_t total_bytes = 0; +}; + +StagingMeta build_meta( + const std::vector& segment_tensors, + const std::vector& segment_block_bytes, + const std::vector& chunk_block_counts, + const std::vector& block_ids, + torch::Tensor device_buf, + hipStream_t stream) { + TORCH_CHECK(device_buf.is_cuda(), "device_buf must be a CUDA/HIP tensor"); + TORCH_CHECK(device_buf.dtype() == torch::kUInt8, "device_buf must be uint8"); + TORCH_CHECK(device_buf.is_contiguous(), "device_buf must be contiguous"); + TORCH_CHECK( + segment_tensors.size() == segment_block_bytes.size(), + "segment_tensors and segment_block_bytes size mismatch"); + + const int64_t num_segments = static_cast(segment_tensors.size()); + const int64_t num_chunks = static_cast(chunk_block_counts.size()); + TORCH_CHECK(num_segments > 0, "at least one segment is required"); + + std::vector segment_ptr_values(num_segments); + std::vector segment_prefix_values(num_segments); + int64_t bytes_per_block = 0; + for (int64_t i = 0; i < num_segments; ++i) { + const auto& seg = segment_tensors[i]; + TORCH_CHECK(seg.is_cuda(), "segment tensor must be CUDA/HIP"); + TORCH_CHECK(seg.device() == device_buf.device(), "segment/device mismatch"); + TORCH_CHECK(seg.is_contiguous(), "segment tensor must be contiguous"); + TORCH_CHECK(segment_block_bytes[i] > 0, "segment block bytes must be > 0"); + segment_ptr_values[i] = + reinterpret_cast(static_cast(seg.data_ptr())); + segment_prefix_values[i] = bytes_per_block; + bytes_per_block += segment_block_bytes[i]; + } + + std::vector chunk_block_offsets(num_chunks); + std::vector chunk_output_bases(num_chunks); + int64_t block_offset = 0; + int64_t byte_offset = 0; + for (int64_t c = 0; c < num_chunks; ++c) { + const int64_t nblocks = chunk_block_counts[c]; + TORCH_CHECK(nblocks >= 0, "chunk block count must be non-negative"); + chunk_block_offsets[c] = block_offset; + chunk_output_bases[c] = byte_offset; + block_offset += nblocks; + byte_offset += nblocks * bytes_per_block; + } + TORCH_CHECK( + static_cast(block_ids.size()) == block_offset, + "block_ids length does not match chunk block counts"); + TORCH_CHECK( + device_buf.numel() >= byte_offset, + "device_buf is smaller than chunk-major staging output"); + + StagingMeta meta; + meta.segment_ptrs = make_device_i64(segment_ptr_values, device_buf.device(), stream); + meta.segment_block_bytes = + make_device_i64(segment_block_bytes, device_buf.device(), stream); + meta.segment_prefix_bytes = + make_device_i64(segment_prefix_values, device_buf.device(), stream); + meta.chunk_block_counts = + make_device_i64(chunk_block_counts, device_buf.device(), stream); + meta.chunk_block_offsets = + make_device_i64(chunk_block_offsets, device_buf.device(), stream); + meta.chunk_output_bases = + make_device_i64(chunk_output_bases, device_buf.device(), stream); + meta.block_ids = make_device_i64(block_ids, device_buf.device(), stream); + meta.num_chunks = num_chunks; + meta.num_segments = num_segments; + meta.total_bytes = byte_offset; + return meta; +} + +hipStream_t current_hip_stream(torch::Device device) { + c10::hip::HIPGuardMasqueradingAsCUDA guard(device); + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(device.index()); + return stream.stream(); +} + +} // namespace + +void fused_pack_chunk_major( + std::vector segment_tensors, + std::vector segment_block_bytes, + std::vector chunk_block_counts, + std::vector block_ids, + torch::Tensor device_buf) { + auto stream = current_hip_stream(device_buf.device()); + auto meta = build_meta( + segment_tensors, + segment_block_bytes, + chunk_block_counts, + block_ids, + device_buf, + stream); + if (meta.total_bytes == 0) { + return; + } + ATOM_HIP_CHECK(launch_fused_pack_chunk_major( + device_buf.data_ptr(), + meta.segment_ptrs.data_ptr(), + meta.segment_block_bytes.data_ptr(), + meta.segment_prefix_bytes.data_ptr(), + meta.chunk_block_counts.data_ptr(), + meta.chunk_block_offsets.data_ptr(), + meta.chunk_output_bases.data_ptr(), + meta.block_ids.data_ptr(), + meta.num_chunks, + meta.num_segments, + stream)); +} + +void fused_unpack_chunk_major( + torch::Tensor device_buf, + std::vector segment_tensors, + std::vector segment_block_bytes, + std::vector chunk_block_counts, + std::vector block_ids) { + auto stream = current_hip_stream(device_buf.device()); + auto meta = build_meta( + segment_tensors, + segment_block_bytes, + chunk_block_counts, + block_ids, + device_buf, + stream); + if (meta.total_bytes == 0) { + return; + } + ATOM_HIP_CHECK(launch_fused_unpack_chunk_major( + device_buf.data_ptr(), + meta.segment_ptrs.data_ptr(), + meta.segment_block_bytes.data_ptr(), + meta.segment_prefix_bytes.data_ptr(), + meta.chunk_block_counts.data_ptr(), + meta.chunk_block_offsets.data_ptr(), + meta.chunk_output_bases.data_ptr(), + meta.block_ids.data_ptr(), + meta.num_chunks, + meta.num_segments, + stream)); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_pack_chunk_major", &fused_pack_chunk_major); + m.def("fused_unpack_chunk_major", &fused_unpack_chunk_major); +} diff --git a/atom/kv_transfer/offload/native_kv_staging.py b/atom/kv_transfer/offload/native_kv_staging.py new file mode 100644 index 0000000000..73b33fdf5b --- /dev/null +++ b/atom/kv_transfer/offload/native_kv_staging.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Optional HIP fused chunk-major staging for ATOM LMCache offload.""" + +from __future__ import annotations + +from pathlib import Path + +from torch.utils.cpp_extension import load + +_EXT = None + + +def _load_ext(): + global _EXT + if _EXT is None: + base = Path(__file__).parent + _EXT = load( + name="atom_lmcache_native_kv_staging", + sources=[ + str(base / "native_kv_staging.cpp"), + str(base / "native_kv_staging_kernel.hip"), + ], + extra_cflags=["-O3"], + extra_cuda_cflags=["-O3"], + verbose=False, + ) + return _EXT + + +def load_extension() -> None: + _load_ext() + + +def fused_pack_chunk_major( + segment_tensors, + segment_block_bytes, + chunk_block_counts, + block_ids, + device_buf, +) -> None: + _load_ext().fused_pack_chunk_major( + segment_tensors, + [int(x) for x in segment_block_bytes], + [int(x) for x in chunk_block_counts], + [int(x) for x in block_ids], + device_buf, + ) + + +def fused_unpack_chunk_major( + device_buf, + segment_tensors, + segment_block_bytes, + chunk_block_counts, + block_ids, +) -> None: + _load_ext().fused_unpack_chunk_major( + device_buf, + segment_tensors, + [int(x) for x in segment_block_bytes], + [int(x) for x in chunk_block_counts], + [int(x) for x in block_ids], + ) diff --git a/atom/kv_transfer/offload/native_kv_staging_kernel.hip b/atom/kv_transfer/offload/native_kv_staging_kernel.hip new file mode 100644 index 0000000000..652276be3e --- /dev/null +++ b/atom/kv_transfer/offload/native_kv_staging_kernel.hip @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include + +namespace { + +constexpr int kThreads = 256; + +__global__ void pack_chunk_major_kernel( + uint8_t* __restrict__ device_buf, + const int64_t* __restrict__ segment_ptrs, + const int64_t* __restrict__ segment_block_bytes, + const int64_t* __restrict__ segment_prefix_bytes, + const int64_t* __restrict__ chunk_block_counts, + const int64_t* __restrict__ chunk_block_offsets, + const int64_t* __restrict__ chunk_output_bases, + const int64_t* __restrict__ block_ids, + int64_t num_segments) { + const int64_t job = static_cast(blockIdx.x); + const int64_t chunk_id = job / num_segments; + const int64_t seg_id = job - chunk_id * num_segments; + const int64_t nblocks = chunk_block_counts[chunk_id]; + const int64_t seg_bytes = segment_block_bytes[seg_id]; + const int64_t nbytes = nblocks * seg_bytes; + if (nbytes <= 0) { + return; + } + + const auto* src_base = + reinterpret_cast(segment_ptrs[seg_id]); + uint8_t* dst_base = device_buf + chunk_output_bases[chunk_id] + + segment_prefix_bytes[seg_id] * nblocks; + const int64_t block_offset = chunk_block_offsets[chunk_id]; + + for (int64_t i = static_cast(threadIdx.x); i < nbytes; + i += static_cast(blockDim.x)) { + const int64_t local_block = i / seg_bytes; + const int64_t byte_in_block = i - local_block * seg_bytes; + const int64_t physical_block = block_ids[block_offset + local_block]; + dst_base[i] = src_base[physical_block * seg_bytes + byte_in_block]; + } +} + +__global__ void unpack_chunk_major_kernel( + const uint8_t* __restrict__ device_buf, + const int64_t* __restrict__ segment_ptrs, + const int64_t* __restrict__ segment_block_bytes, + const int64_t* __restrict__ segment_prefix_bytes, + const int64_t* __restrict__ chunk_block_counts, + const int64_t* __restrict__ chunk_block_offsets, + const int64_t* __restrict__ chunk_output_bases, + const int64_t* __restrict__ block_ids, + int64_t num_segments) { + const int64_t job = static_cast(blockIdx.x); + const int64_t chunk_id = job / num_segments; + const int64_t seg_id = job - chunk_id * num_segments; + const int64_t nblocks = chunk_block_counts[chunk_id]; + const int64_t seg_bytes = segment_block_bytes[seg_id]; + const int64_t nbytes = nblocks * seg_bytes; + if (nbytes <= 0) { + return; + } + + const auto* src_base = device_buf + chunk_output_bases[chunk_id] + + segment_prefix_bytes[seg_id] * nblocks; + auto* dst_base = reinterpret_cast(segment_ptrs[seg_id]); + const int64_t block_offset = chunk_block_offsets[chunk_id]; + + for (int64_t i = static_cast(threadIdx.x); i < nbytes; + i += static_cast(blockDim.x)) { + const int64_t local_block = i / seg_bytes; + const int64_t byte_in_block = i - local_block * seg_bytes; + const int64_t physical_block = block_ids[block_offset + local_block]; + dst_base[physical_block * seg_bytes + byte_in_block] = src_base[i]; + } +} + +} // namespace + +hipError_t launch_fused_pack_chunk_major( + uint8_t* device_buf, + const int64_t* segment_ptrs, + const int64_t* segment_block_bytes, + const int64_t* segment_prefix_bytes, + const int64_t* chunk_block_counts, + const int64_t* chunk_block_offsets, + const int64_t* chunk_output_bases, + const int64_t* block_ids, + int64_t num_chunks, + int64_t num_segments, + hipStream_t stream) { + const dim3 grid(static_cast(num_chunks * num_segments)); + const dim3 block(kThreads); + hipLaunchKernelGGL( + pack_chunk_major_kernel, + grid, + block, + 0, + stream, + device_buf, + segment_ptrs, + segment_block_bytes, + segment_prefix_bytes, + chunk_block_counts, + chunk_block_offsets, + chunk_output_bases, + block_ids, + num_segments); + return hipGetLastError(); +} + +hipError_t launch_fused_unpack_chunk_major( + const uint8_t* device_buf, + const int64_t* segment_ptrs, + const int64_t* segment_block_bytes, + const int64_t* segment_prefix_bytes, + const int64_t* chunk_block_counts, + const int64_t* chunk_block_offsets, + const int64_t* chunk_output_bases, + const int64_t* block_ids, + int64_t num_chunks, + int64_t num_segments, + hipStream_t stream) { + const dim3 grid(static_cast(num_chunks * num_segments)); + const dim3 block(kThreads); + hipLaunchKernelGGL( + unpack_chunk_major_kernel, + grid, + block, + 0, + stream, + device_buf, + segment_ptrs, + segment_block_bytes, + segment_prefix_bytes, + chunk_block_counts, + chunk_block_offsets, + chunk_output_bases, + block_ids, + num_segments); + return hipGetLastError(); +} diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index c278fb897a..43346b7b5c 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -227,7 +227,9 @@ def test_codec_device_buffer_roundtrip_noncontiguous_blocks(monkeypatch): ), } kv_caches = { - name: SimpleNamespace(k_cache=layer.k_cache.clone(), v_cache=layer.v_cache.clone()) + name: SimpleNamespace( + k_cache=layer.k_cache.clone(), v_cache=layer.v_cache.clone() + ) for name, layer in original.items() } for layer in kv_caches.values(): @@ -307,6 +309,265 @@ def test_lmcache_connector_maps_token_ranges_to_block_ids(monkeypatch): assert torch.count_nonzero(kv_caches["l0"].v_cache[0]) == 0 +def test_lmcache_connector_fused_chunk_fastpath_uses_chunk_major(monkeypatch): + from contextlib import nullcontext + + import torch + + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") + original = { + "l0": SimpleNamespace( + k_cache=torch.arange(6 * 2, dtype=torch.uint8).reshape(6, 2), + v_cache=(torch.arange(6 * 3, dtype=torch.uint8).reshape(6, 3) + 51), + k_scale=None, + v_scale=None, + ) + } + kv_caches = { + "l0": SimpleNamespace( + k_cache=original["l0"].k_cache.clone(), + v_cache=original["l0"].v_cache.clone(), + k_scale=None, + v_scale=None, + ) + } + codec = ATOMKVByteCodec(kv_caches) + connector = ATOMLMCacheGPUConnector(codec, block_size=4) + monkeypatch.setattr(connector, "_can_use_fused_chunk_major", lambda: True) + orig_pack = codec.gpu_to_chunk_major_device_buffer + orig_unpack = codec.chunk_major_device_buffer_to_gpu + monkeypatch.setattr( + codec, + "gpu_to_chunk_major_device_buffer", + lambda device_buf, block_id_groups, stream=None: orig_pack( + device_buf, block_id_groups, stream=None + ), + ) + monkeypatch.setattr( + codec, + "chunk_major_device_buffer_to_gpu", + lambda device_buf, block_id_groups, stream=None: orig_unpack( + device_buf, block_id_groups, stream=None + ), + ) + + class _FakeEvent: + def record(self, stream) -> None: + pass + + class _FakeStream: + def wait_event(self, event) -> None: + pass + + def synchronize(self) -> None: + pass + + class _FakeState: + def __init__(self) -> None: + self.use_cuda = True + self.pack_stream = _FakeStream() + self.copy_stream = _FakeStream() + self.next_slot = 0 + self.slots = [ + SimpleNamespace( + tensor=None, + ready_event=_FakeEvent(), + free_event=_FakeEvent(), + free_event_valid=False, + ), + SimpleNamespace( + tensor=None, + ready_event=_FakeEvent(), + free_event=_FakeEvent(), + free_event_valid=False, + ), + ] + + def stream_ctx(self, stream): + return nullcontext() + + fake_state = _FakeState() + monkeypatch.setattr(connector, "_thread_state", lambda: fake_state) + memory_objs = [ + SimpleNamespace( + tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) + ), + SimpleNamespace( + tensor=torch.empty(1 * codec.bytes_per_block, dtype=torch.uint8) + ), + ] + + connector.batched_from_gpu( + memory_objs, + [4, 12], + [12, 16], + block_ids=[0, 1, 2, 3, 4, 5], + ) + + expected0 = torch.cat( + [ + original["l0"].k_cache[[1, 2]].reshape(-1), + original["l0"].v_cache[[1, 2]].reshape(-1), + ] + ) + expected1 = torch.cat( + [ + original["l0"].k_cache[[3]].reshape(-1), + original["l0"].v_cache[[3]].reshape(-1), + ] + ) + assert connector.last_fastpath() == "fused_chunk" + assert torch.equal(memory_objs[0].tensor, expected0) + assert torch.equal(memory_objs[1].tensor, expected1) + + kv_caches["l0"].k_cache.zero_() + kv_caches["l0"].v_cache.zero_() + connector.batched_to_gpu( + memory_objs, + [4, 12], + [12, 16], + block_ids=[0, 1, 2, 3, 4, 5], + ) + + assert connector.last_fastpath() == "fused_chunk" + for bid in [1, 2, 3]: + assert torch.equal(kv_caches["l0"].k_cache[bid], original["l0"].k_cache[bid]) + assert torch.equal(kv_caches["l0"].v_cache[bid], original["l0"].v_cache[bid]) + assert torch.count_nonzero(kv_caches["l0"].k_cache[0]) == 0 + assert torch.count_nonzero(kv_caches["l0"].v_cache[0]) == 0 + + +def test_codec_chunk_major_device_buffer_layout(monkeypatch): + import torch + + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") + original = { + "l0": SimpleNamespace( + k_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), + v_cache=(torch.arange(4 * 3, dtype=torch.uint8).reshape(4, 3) + 51), + k_scale=None, + v_scale=None, + ) + } + kv_caches = { + "l0": SimpleNamespace( + k_cache=original["l0"].k_cache.clone(), + v_cache=original["l0"].v_cache.clone(), + k_scale=None, + v_scale=None, + ) + } + codec = ATOMKVByteCodec(kv_caches) + block_id_groups = [[0, 1], [2, 3]] + device_buf = torch.empty( + 4 * codec.bytes_per_block, + dtype=torch.uint8, + device=codec.device, + ) + + codec.gpu_to_chunk_major_device_buffer(device_buf, block_id_groups) + + expected = torch.cat( + [ + original["l0"].k_cache[[0, 1]].reshape(-1), + original["l0"].v_cache[[0, 1]].reshape(-1), + original["l0"].k_cache[[2, 3]].reshape(-1), + original["l0"].v_cache[[2, 3]].reshape(-1), + ] + ) + assert torch.equal(device_buf.cpu(), expected.cpu()) + + kv_caches["l0"].k_cache.zero_() + kv_caches["l0"].v_cache.zero_() + codec.chunk_major_device_buffer_to_gpu(device_buf, block_id_groups) + + assert torch.equal(kv_caches["l0"].k_cache, original["l0"].k_cache) + assert torch.equal(kv_caches["l0"].v_cache, original["l0"].v_cache) + + +def test_codec_chunk_major_handles_tail_and_sparse_blocks(monkeypatch): + import torch + + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") + original = { + "l0": SimpleNamespace( + k_cache=torch.arange(6 * 2, dtype=torch.uint8).reshape(6, 2), + v_cache=(torch.arange(6 * 4, dtype=torch.uint8).reshape(6, 4) + 31), + k_scale=(torch.arange(6, dtype=torch.uint8).reshape(6, 1) + 101), + v_scale=None, + ), + "l1": SimpleNamespace( + k_cache=(torch.arange(6 * 3, dtype=torch.uint8).reshape(6, 3) + 151), + v_cache=(torch.arange(6 * 2, dtype=torch.uint8).reshape(6, 2) + 201), + k_scale=None, + v_scale=None, + ), + } + kv_caches = { + name: SimpleNamespace( + k_cache=layer.k_cache.clone(), + v_cache=layer.v_cache.clone(), + k_scale=layer.k_scale.clone() if layer.k_scale is not None else None, + v_scale=None, + ) + for name, layer in original.items() + } + codec = ATOMKVByteCodec(kv_caches) + block_id_groups = [[4, 1, 3], [0]] + device_buf = torch.empty( + 4 * codec.bytes_per_block, + dtype=torch.uint8, + device=codec.device, + ) + + codec.gpu_to_chunk_major_device_buffer(device_buf, block_id_groups) + for layer in kv_caches.values(): + layer.k_cache.zero_() + layer.v_cache.zero_() + if layer.k_scale is not None: + layer.k_scale.zero_() + codec.chunk_major_device_buffer_to_gpu(device_buf, block_id_groups) + + for name, layer in kv_caches.items(): + src = original[name] + for bid in [4, 1, 3, 0]: + assert torch.equal(layer.k_cache[bid], src.k_cache[bid]) + assert torch.equal(layer.v_cache[bid], src.v_cache[bid]) + if layer.k_scale is not None: + assert torch.equal(layer.k_scale[bid], src.k_scale[bid]) + + +def test_codec_chunk_major_rejects_duplicate_block_ids(monkeypatch): + import torch + + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") + kv_caches = { + "l0": SimpleNamespace( + k_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), + v_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), + k_scale=None, + v_scale=None, + ) + } + codec = ATOMKVByteCodec(kv_caches) + device_buf = torch.empty(3 * codec.bytes_per_block, dtype=torch.uint8) + + with pytest.raises(ValueError, match="duplicate block ids"): + codec.gpu_to_chunk_major_device_buffer(device_buf, [[0, 1], [1]]) + + @pytest.mark.parametrize("layout", ["block", "segment", "segment_indexed"]) @pytest.mark.parametrize("method_name", ["gpu_to_host", "host_to_gpu"]) def test_codec_rejects_invalid_block_ids_before_copy(monkeypatch, layout, method_name): From 7b39986f6ad46cf113ff859af4b9774dd23e39f7 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Wed, 3 Jun 2026 22:19:34 -0500 Subject: [PATCH 14/27] Add bounded LMCache staging with chunk2 default --- atom/kv_transfer/offload/connector.py | 99 ++- atom/kv_transfer/offload/lmcache_compat.py | 677 +++++++++++++++---- atom/model_ops/attentions/aiter_attention.py | 49 ++ atom/model_ops/attentions/aiter_mla.py | 8 + atom/model_ops/attentions/backends.py | 10 + tests/test_lmcache_offload_connector.py | 255 ++++++- 6 files changed, 937 insertions(+), 161 deletions(-) diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index f796640160..cb322296a4 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -121,7 +121,11 @@ def register_kv_caches(self, kv_caches: dict, transfer_tensors=None) -> None: atom_block_size=self.block_size, bytes_per_block=self._codec.bytes_per_block, ) - gpu_connector = ATOMLMCacheGPUConnector(self._codec, self.block_size) + gpu_connector = ATOMLMCacheGPUConnector( + self._codec, + self.block_size, + chunk_size=self.chunk_size, + ) self._engine = LMCacheEngineBuilder.get_or_create( f"atom-offload-{rank}", @@ -171,11 +175,19 @@ def _logged_lookup(*a, **k): logger.info( "LMCache offload worker rank=%d: bytes_per_block=%d chunk=%d " - "codec_layout=%s save=%s load=%s", + "codec_layout=%s gpu_staging_slots=%d " + "gpu_staging_chunk_bytes=%d gpu_staging_group_chunks=%d " + "gpu_staging_capacity_bytes=%d release_gpu_staging=%s " + "save=%s load=%s", rank, self._codec.bytes_per_block, self.chunk_size, self._codec.layout, + gpu_connector.staging_slots, + gpu_connector.gpu_staging_chunk_bytes, + gpu_connector.gpu_staging_group_chunks, + gpu_connector.gpu_staging_capacity_bytes, + gpu_connector.release_gpu_staging_after_transfer, self._do_save, self._do_load, ) @@ -276,6 +288,15 @@ def _last_gpu_connector_fastpath(self) -> str: except Exception: return "unknown" + def _last_gpu_connector_transfer_stats(self) -> dict[str, int | float]: + gpu_connector = getattr(getattr(self, "_engine", None), "gpu_connector", None) + if gpu_connector is None or not hasattr(gpu_connector, "last_transfer_stats"): + return {} + try: + return dict(gpu_connector.last_transfer_stats()) + except Exception: + return {} + def _reset_gpu_connector_fastpath(self) -> None: gpu_connector = getattr(getattr(self, "_engine", None), "gpu_connector", None) if gpu_connector is None or not hasattr(gpu_connector, "reset_fastpath"): @@ -350,6 +371,7 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: ) retrieve_ms = (time.perf_counter() - t_retrieve0) * 1000 fastpath = self._last_gpu_connector_fastpath() + transfer_stats = self._last_gpu_connector_transfer_stats() self._lookup_unpin(req.req_id) loaded = bool(ret_mask[hbm:lmc].all().item()) if lmc > hbm else True with self._lock: @@ -367,14 +389,34 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: lmc=lmc, retrieved=int(ret_mask.sum().item()), fastpath=fastpath, + chunks=transfer_stats.get("chunks", 0), + groups=transfer_stats.get("groups", 0), + max_chunk_bytes=transfer_stats.get("max_chunk_bytes", 0), + max_group_bytes=transfer_stats.get("max_group_bytes", 0), + gpu_staging_chunk_bytes=transfer_stats.get("gpu_staging_chunk_bytes", 0), + gpu_staging_group_chunks=transfer_stats.get("gpu_staging_group_chunks", 0), + gpu_staging_capacity_bytes=transfer_stats.get( + "gpu_staging_capacity_bytes", 0 + ), + total_bytes=transfer_stats.get("total_bytes", 0), + pack_ms=f"{float(transfer_stats.get('pack_ms', 0.0)):.2f}", + copy_ms=f"{float(transfer_stats.get('copy_ms', 0.0)):.2f}", + sync_ms=f"{float(transfer_stats.get('sync_ms', 0.0)):.2f}", + transfer_ms=f"{float(transfer_stats.get('transfer_ms', 0.0)):.2f}", + effective_gbps=f"{float(transfer_stats.get('effective_gbps', 0.0)):.2f}", retrieve_ms=f"{retrieve_ms:.2f}", total_ms=f"{total_ms:.2f}", ) if self._profile_enabled(): logger.info( "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " - "retrieved=%d status=%s fastpath=%s retrieve_ms=%.2f " - "total_ms=%.2f", + "retrieved=%d status=%s fastpath=%s chunks=%d " + "groups=%d max_chunk_bytes=%d max_group_bytes=%d " + "gpu_staging_chunk_bytes=%d gpu_staging_group_chunks=%d " + "gpu_staging_capacity_bytes=%d total_bytes=%d " + "pack_ms=%.2f copy_ms=%.2f sync_ms=%.2f " + "transfer_ms=%.2f effective_gbps=%.2f " + "retrieve_ms=%.2f total_ms=%.2f", getattr(self, "_rank", "?"), req.req_id, hbm, @@ -382,6 +424,19 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: int(ret_mask.sum().item()), "ok" if loaded else "miss", fastpath, + int(transfer_stats.get("chunks", 0)), + int(transfer_stats.get("groups", 0)), + int(transfer_stats.get("max_chunk_bytes", 0)), + int(transfer_stats.get("max_group_bytes", 0)), + int(transfer_stats.get("gpu_staging_chunk_bytes", 0)), + int(transfer_stats.get("gpu_staging_group_chunks", 0)), + int(transfer_stats.get("gpu_staging_capacity_bytes", 0)), + int(transfer_stats.get("total_bytes", 0)), + float(transfer_stats.get("pack_ms", 0.0)), + float(transfer_stats.get("copy_ms", 0.0)), + float(transfer_stats.get("sync_ms", 0.0)), + float(transfer_stats.get("transfer_ms", 0.0)), + float(transfer_stats.get("effective_gbps", 0.0)), retrieve_ms, total_ms, ) @@ -428,6 +483,7 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: ) store_ms = (time.perf_counter() - t_store0) * 1000 fastpath = self._last_gpu_connector_fastpath() + transfer_stats = self._last_gpu_connector_transfer_stats() with self._lock: self._done_save.add(req.req_id) total_ms = (time.perf_counter() - t_total0) * 1000 @@ -439,18 +495,51 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: toks=len(toks), skip=skip, fastpath=fastpath, + chunks=transfer_stats.get("chunks", 0), + groups=transfer_stats.get("groups", 0), + max_chunk_bytes=transfer_stats.get("max_chunk_bytes", 0), + max_group_bytes=transfer_stats.get("max_group_bytes", 0), + gpu_staging_chunk_bytes=transfer_stats.get("gpu_staging_chunk_bytes", 0), + gpu_staging_group_chunks=transfer_stats.get("gpu_staging_group_chunks", 0), + gpu_staging_capacity_bytes=transfer_stats.get( + "gpu_staging_capacity_bytes", 0 + ), + total_bytes=transfer_stats.get("total_bytes", 0), + pack_ms=f"{float(transfer_stats.get('pack_ms', 0.0)):.2f}", + copy_ms=f"{float(transfer_stats.get('copy_ms', 0.0)):.2f}", + sync_ms=f"{float(transfer_stats.get('sync_ms', 0.0)):.2f}", + transfer_ms=f"{float(transfer_stats.get('transfer_ms', 0.0)):.2f}", + effective_gbps=f"{float(transfer_stats.get('effective_gbps', 0.0)):.2f}", store_ms=f"{store_ms:.2f}", total_ms=f"{total_ms:.2f}", ) if self._profile_enabled(): logger.info( "[OFFLOAD-SAVE-PROF] rank=%s req=%s toks=%d skip=%d " - "fastpath=%s store_ms=%.2f total_ms=%.2f", + "fastpath=%s chunks=%d groups=%d max_chunk_bytes=%d " + "max_group_bytes=%d gpu_staging_chunk_bytes=%d " + "gpu_staging_group_chunks=%d gpu_staging_capacity_bytes=%d " + "total_bytes=%d pack_ms=%.2f copy_ms=%.2f sync_ms=%.2f " + "transfer_ms=%.2f effective_gbps=%.2f " + "store_ms=%.2f total_ms=%.2f", getattr(self, "_rank", "?"), req.req_id, len(toks), skip, fastpath, + int(transfer_stats.get("chunks", 0)), + int(transfer_stats.get("groups", 0)), + int(transfer_stats.get("max_chunk_bytes", 0)), + int(transfer_stats.get("max_group_bytes", 0)), + int(transfer_stats.get("gpu_staging_chunk_bytes", 0)), + int(transfer_stats.get("gpu_staging_group_chunks", 0)), + int(transfer_stats.get("gpu_staging_capacity_bytes", 0)), + int(transfer_stats.get("total_bytes", 0)), + float(transfer_stats.get("pack_ms", 0.0)), + float(transfer_stats.get("copy_ms", 0.0)), + float(transfer_stats.get("sync_ms", 0.0)), + float(transfer_stats.get("transfer_ms", 0.0)), + float(transfer_stats.get("effective_gbps", 0.0)), store_ms, total_ms, ) diff --git a/atom/kv_transfer/offload/lmcache_compat.py b/atom/kv_transfer/offload/lmcache_compat.py index 0be2008524..fad16df602 100644 --- a/atom/kv_transfer/offload/lmcache_compat.py +++ b/atom/kv_transfer/offload/lmcache_compat.py @@ -12,7 +12,10 @@ from __future__ import annotations +from dataclasses import dataclass +import os import threading +import time from typing import Any import torch @@ -96,8 +99,43 @@ def __init__(self, use_cuda: bool) -> None: self.free_event = torch.cuda.Event(blocking=False) +def _env_flag(name: str, default: str = "0") -> bool: + return os.environ.get(name, default).lower() not in ("0", "false", "no", "off") + + +def _env_int(name: str, default: int, *, min_value: int = 1) -> int: + raw = os.environ.get(name) + if raw is None: + return default + try: + value = int(raw) + except ValueError as exc: + raise ValueError(f"{name} must be an integer, got {raw!r}") from exc + if value < min_value: + raise ValueError(f"{name} must be >= {min_value}, got {value}") + return value + + +def _env_optional_int(name: str, *, min_value: int = 1) -> int | None: + raw = os.environ.get(name) + if raw is None or raw == "": + return None + try: + value = int(raw) + except ValueError as exc: + raise ValueError(f"{name} must be an integer, got {raw!r}") from exc + if value < min_value: + raise ValueError(f"{name} must be >= {min_value}, got {value}") + return value + + class _ThreadTransferState: - def __init__(self, device: torch.device, use_cuda: bool) -> None: + def __init__( + self, + device: torch.device, + use_cuda: bool, + staging_slots: int, + ) -> None: self.device = device self.use_cuda = use_cuda self.pack_stream = None @@ -108,9 +146,9 @@ def __init__(self, device: torch.device, use_cuda: bool) -> None: with torch.cuda.device(device): self.pack_stream = torch.cuda.Stream() self.copy_stream = torch.cuda.Stream() - self.slots = [_StagingSlot(use_cuda), _StagingSlot(use_cuda)] + self.slots = [_StagingSlot(use_cuda) for _ in range(staging_slots)] else: - self.slots = [_StagingSlot(use_cuda), _StagingSlot(use_cuda)] + self.slots = [_StagingSlot(use_cuda) for _ in range(staging_slots)] def stream_ctx(self, stream): if stream is None: @@ -118,26 +156,141 @@ def stream_ctx(self, stream): return torch.cuda.stream(stream) +@dataclass(frozen=True) +class _TransferChunk: + memory_obj: Any + block_ids: list[int] + tensor: torch.Tensor + nbytes: int + + +@dataclass(frozen=True) +class _TransferGroup: + chunks: list[_TransferChunk] + nbytes: int + + class ATOMLMCacheGPUConnector: """LMCache GPUConnectorInterface for ATOM's opaque KV-block byte layout.""" - def __init__(self, codec: ATOMKVByteCodec, block_size: int) -> None: + def __init__( + self, + codec: ATOMKVByteCodec, + block_size: int, + *, + chunk_size: int | None = None, + ) -> None: self.codec = codec self.block_size = int(block_size) if self.block_size <= 0: raise ValueError("ATOM LMCache connector: block_size must be > 0") + self.chunk_size = int(chunk_size if chunk_size is not None else block_size) + if self.chunk_size <= 0: + raise ValueError("ATOM LMCache connector: chunk_size must be > 0") + if self.chunk_size % self.block_size != 0: + raise ValueError( + "LMCache chunk size must be divisible by ATOM KV block size: " + f"chunk_size={self.chunk_size}, block_size={self.block_size}" + ) + self._blocks_per_lmcache_chunk = self.chunk_size // self.block_size + self._gpu_staging_chunk_bytes = ( + self._blocks_per_lmcache_chunk * self.codec.bytes_per_block + ) + if self._gpu_staging_chunk_bytes <= 0: + raise ValueError( + "ATOM LMCache connector: GPU staging chunk bytes must be > 0" + ) self.device = torch.device(codec.device) self._tls = threading.local() + requested_group_chunks = _env_int("OFFLOAD_GPU_STAGING_CHUNKS", 2) + max_staging_bytes = _env_optional_int("OFFLOAD_GPU_STAGING_MAX_BYTES") + if max_staging_bytes is not None: + if max_staging_bytes < self._gpu_staging_chunk_bytes: + raise ValueError( + "OFFLOAD_GPU_STAGING_MAX_BYTES must be at least one " + "LMCache chunk: " + f"max_bytes={max_staging_bytes}, " + f"chunk_bytes={self._gpu_staging_chunk_bytes}" + ) + requested_group_chunks = min( + requested_group_chunks, + max_staging_bytes // self._gpu_staging_chunk_bytes, + ) + self._staging_group_chunks = max(1, int(requested_group_chunks)) + self._gpu_staging_capacity_bytes = ( + self._staging_group_chunks * self._gpu_staging_chunk_bytes + ) + self._staging_slots = _env_int("OFFLOAD_GPU_STAGING_SLOTS", 1) + self._release_gpu_staging_after_transfer = _env_flag( + "OFFLOAD_RELEASE_GPU_STAGING_AFTER_TRANSFER" + ) + + @property + def staging_slots(self) -> int: + return self._staging_slots + + @property + def gpu_staging_chunk_bytes(self) -> int: + return self._gpu_staging_chunk_bytes + + @property + def gpu_staging_group_chunks(self) -> int: + return self._staging_group_chunks + + @property + def gpu_staging_capacity_bytes(self) -> int: + return self._gpu_staging_capacity_bytes + + @property + def release_gpu_staging_after_transfer(self) -> bool: + return self._release_gpu_staging_after_transfer def _set_last_fastpath(self, value: str) -> None: self._tls.last_fastpath = value + def _set_last_transfer_stats( + self, + *, + chunks: int = 0, + max_chunk_bytes: int = 0, + groups: int = 0, + max_group_bytes: int = 0, + total_bytes: int = 0, + pack_ms: float = 0.0, + copy_ms: float = 0.0, + sync_ms: float = 0.0, + transfer_ms: float = 0.0, + ) -> None: + effective_gbps = 0.0 + if transfer_ms > 0 and total_bytes > 0: + effective_gbps = total_bytes / (transfer_ms * 1_000_000.0) + self._tls.last_transfer_stats = { + "chunks": int(chunks), + "max_chunk_bytes": int(max_chunk_bytes), + "groups": int(groups), + "max_group_bytes": int(max_group_bytes), + "total_bytes": int(total_bytes), + "gpu_staging_chunk_bytes": self._gpu_staging_chunk_bytes, + "gpu_staging_group_chunks": self._staging_group_chunks, + "gpu_staging_capacity_bytes": self._gpu_staging_capacity_bytes, + "gpu_staging_slots": self._staging_slots, + "pack_ms": float(pack_ms), + "copy_ms": float(copy_ms), + "sync_ms": float(sync_ms), + "transfer_ms": float(transfer_ms), + "effective_gbps": float(effective_gbps), + } + def reset_fastpath(self) -> None: self._set_last_fastpath("none") + self._set_last_transfer_stats() def last_fastpath(self) -> str: return getattr(self._tls, "last_fastpath", "unknown") + def last_transfer_stats(self) -> dict[str, int | float]: + return dict(getattr(self._tls, "last_transfer_stats", {})) + def _use_cuda(self) -> bool: return self.device.type == "cuda" @@ -149,15 +302,28 @@ def _thread_state(self) -> _ThreadTransferState: key = str(self.device) state = states.get(key) if state is None: - state = _ThreadTransferState(self.device, self._use_cuda()) + state = _ThreadTransferState( + self.device, + self._use_cuda(), + self._staging_slots, + ) states[key] = state return state def _ensure_slot(self, slot: _StagingSlot, nbytes: int) -> torch.Tensor: nbytes = int(nbytes) - if slot.tensor is None or int(slot.tensor.numel()) < nbytes: + if nbytes > self._gpu_staging_capacity_bytes: + raise RuntimeError( + "ATOM LMCache connector internal error: transfer group exceeds " + "bounded GPU staging capacity: " + f"nbytes={nbytes}, capacity={self._gpu_staging_capacity_bytes}" + ) + if ( + slot.tensor is None + or int(slot.tensor.numel()) != self._gpu_staging_capacity_bytes + ): slot.tensor = torch.empty( - (nbytes,), + (self._gpu_staging_capacity_bytes,), dtype=torch.uint8, device=self.device, ) @@ -169,24 +335,45 @@ def _next_slot(self, state: _ThreadTransferState) -> _StagingSlot: state.next_slot += 1 return slot + def _release_slot_if_requested(self, slot: _StagingSlot) -> None: + if not self._release_gpu_staging_after_transfer: + return + slot.tensor = None + slot.free_event_valid = False + def _ensure_host_tmp( self, state: _ThreadTransferState, nbytes: int, ) -> torch.Tensor: nbytes = int(nbytes) - if state.host_tmp is None or int(state.host_tmp.numel()) < nbytes: + if nbytes > self._gpu_staging_capacity_bytes: + raise RuntimeError( + "ATOM LMCache connector internal error: transfer group exceeds " + "bounded host staging capacity: " + f"nbytes={nbytes}, capacity={self._gpu_staging_capacity_bytes}" + ) + if ( + state.host_tmp is None + or int(state.host_tmp.numel()) != self._gpu_staging_capacity_bytes + ): if state.use_cuda: try: state.host_tmp = torch.empty( - (nbytes,), + (self._gpu_staging_capacity_bytes,), dtype=torch.uint8, pin_memory=True, ) except RuntimeError: - state.host_tmp = torch.empty((nbytes,), dtype=torch.uint8) + state.host_tmp = torch.empty( + (self._gpu_staging_capacity_bytes,), + dtype=torch.uint8, + ) else: - state.host_tmp = torch.empty((nbytes,), dtype=torch.uint8) + state.host_tmp = torch.empty( + (self._gpu_staging_capacity_bytes,), + dtype=torch.uint8, + ) return state.host_tmp[:nbytes] def _can_use_fused_chunk_major(self) -> bool: @@ -257,6 +444,119 @@ def _ranges_to_block_ids( for start, end in zip(starts, ends, strict=True) ] + def _iter_transfer_chunks( + self, + memory_objs: list[Any], + block_id_groups: list[list[int]], + ) -> list[_TransferChunk]: + chunks: list[_TransferChunk] = [] + for memory_obj, block_ids in zip(memory_objs, block_id_groups, strict=True): + block_count = len(block_ids) + if block_count == 0: + continue + nbytes = block_count * self.codec.bytes_per_block + if nbytes > self._gpu_staging_chunk_bytes: + raise ValueError( + "ATOM LMCache connector: single MemoryObj exceeds bounded " + "GPU staging chunk capacity; caller must pass LMCache " + "chunk-sized ranges: " + f"nbytes={nbytes}, capacity={self._gpu_staging_chunk_bytes}, " + f"blocks={block_count}, max_blocks=" + f"{self._blocks_per_lmcache_chunk}, chunk_size=" + f"{self.chunk_size}, block_size={self.block_size}" + ) + chunks.append( + _TransferChunk( + memory_obj=memory_obj, + block_ids=block_ids, + tensor=self._memory_tensor(memory_obj, nbytes), + nbytes=nbytes, + ) + ) + return chunks + + def _iter_transfer_groups( + self, + chunks: list[_TransferChunk], + ) -> list[_TransferGroup]: + groups: list[_TransferGroup] = [] + current: list[_TransferChunk] = [] + current_bytes = 0 + for chunk in chunks: + would_exceed_count = len(current) >= self._staging_group_chunks + would_exceed_bytes = ( + current_bytes + chunk.nbytes > self._gpu_staging_capacity_bytes + ) + if current and (would_exceed_count or would_exceed_bytes): + groups.append(_TransferGroup(chunks=current, nbytes=current_bytes)) + current = [] + current_bytes = 0 + current.append(chunk) + current_bytes += chunk.nbytes + if current: + groups.append(_TransferGroup(chunks=current, nbytes=current_bytes)) + return groups + + def _record_transfer_stats( + self, + chunks: list[_TransferChunk], + groups: list[_TransferGroup] | None = None, + *, + pack_ms: float = 0.0, + copy_ms: float = 0.0, + sync_ms: float = 0.0, + transfer_ms: float = 0.0, + ) -> None: + if groups is None: + groups = [] + total_bytes = sum(chunk.nbytes for chunk in chunks) + self._set_last_transfer_stats( + chunks=len(chunks), + max_chunk_bytes=max((chunk.nbytes for chunk in chunks), default=0), + groups=len(groups), + max_group_bytes=max((group.nbytes for group in groups), default=0), + total_bytes=total_bytes, + pack_ms=pack_ms, + copy_ms=copy_ms, + sync_ms=sync_ms, + transfer_ms=transfer_ms, + ) + + @staticmethod + def _group_block_ids(group: _TransferGroup) -> list[list[int]]: + return [chunk.block_ids for chunk in group.chunks] + + @staticmethod + def _slice_to_memory_objs(group: _TransferGroup, src_buf: torch.Tensor) -> None: + offset = 0 + for chunk in group.chunks: + chunk.tensor.copy_( + src_buf[offset : offset + chunk.nbytes], + non_blocking=chunk.tensor.device.type != "cpu", + ) + offset += chunk.nbytes + + @staticmethod + def _memory_objs_to_slice(group: _TransferGroup, dst_buf: torch.Tensor) -> None: + offset = 0 + for chunk in group.chunks: + dst_buf[offset : offset + chunk.nbytes].copy_( + chunk.tensor, + non_blocking=chunk.tensor.device.type != "cpu", + ) + offset += chunk.nbytes + + @staticmethod + def _remember_slot(used_slots: list[_StagingSlot], slot: _StagingSlot) -> None: + if not any(existing is slot for existing in used_slots): + used_slots.append(slot) + + def _release_slots_if_requested(self, used_slots: list[_StagingSlot]) -> None: + if not self._release_gpu_staging_after_transfer: + return + for slot in used_slots: + self._release_slot_if_requested(slot) + def from_gpu(self, memory_obj: Any, start: int, end: int, **kwargs) -> None: self.batched_from_gpu([memory_obj], [start], [end], **kwargs) @@ -270,90 +570,130 @@ def batched_from_gpu( ends: list[int], **kwargs, ) -> None: - """Pack ATOM KV blocks to LMCache MemoryObjs via double GPU staging.""" + """Pack ATOM KV blocks to LMCache MemoryObjs via bounded staging.""" if not (len(memory_objs) == len(starts) == len(ends)): raise ValueError("memory_objs, starts, and ends must have equal length") block_id_groups = self._ranges_to_block_ids(starts, ends, **kwargs) if not memory_objs: + self._set_last_transfer_stats() return state = self._thread_state() use_cuda = state.use_cuda - chunk_block_counts = [len(block_ids) for block_ids in block_id_groups] - all_block_ids = [ - block_id for block_ids in block_id_groups for block_id in block_ids - ] - total_nbytes = len(all_block_ids) * self.codec.bytes_per_block - if total_nbytes == 0: + chunks = self._iter_transfer_chunks(memory_objs, block_id_groups) + groups = self._iter_transfer_groups(chunks) + self._record_transfer_stats(chunks, groups) + if not chunks: return - slot = self._next_slot(state) - device_buf = self._ensure_slot(slot, total_nbytes) - dst_tensors = [ - self._memory_tensor( - memory_obj, - block_count * self.codec.bytes_per_block, - ) - for memory_obj, block_count in zip( - memory_objs, - chunk_block_counts, - strict=True, - ) - ] - if self._can_use_fused_chunk_major(): self._set_last_fastpath("fused_chunk") - if slot.free_event_valid: - state.pack_stream.wait_event(slot.free_event) - with state.stream_ctx(state.pack_stream): - self.codec.gpu_to_chunk_major_device_buffer( - device_buf, - block_id_groups, - stream=state.pack_stream, - ) - slot.ready_event.record(state.pack_stream) - state.copy_stream.wait_event(slot.ready_event) - with state.stream_ctx(state.copy_stream): - offset = 0 - for dst, block_count in zip( - dst_tensors, - chunk_block_counts, - strict=True, - ): - nbytes = block_count * self.codec.bytes_per_block - dst.copy_( - device_buf[offset : offset + nbytes], - non_blocking=True, - ) - offset += nbytes - slot.free_event.record(state.copy_stream) - slot.free_event_valid = True - state.copy_stream.synchronize() + used_slots: list[_StagingSlot] = [] + pack_ms = 0.0 + copy_ms = 0.0 + sync_ms = 0.0 + t_total0 = time.perf_counter() + try: + for group in groups: + slot = self._next_slot(state) + self._remember_slot(used_slots, slot) + device_buf = self._ensure_slot(slot, group.nbytes) + if slot.free_event_valid: + state.pack_stream.wait_event(slot.free_event) + t0 = time.perf_counter() + with state.stream_ctx(state.pack_stream): + self.codec.gpu_to_chunk_major_device_buffer( + device_buf, + self._group_block_ids(group), + stream=state.pack_stream, + ) + pack_ms += (time.perf_counter() - t0) * 1000 + slot.ready_event.record(state.pack_stream) + state.copy_stream.wait_event(slot.ready_event) + t0 = time.perf_counter() + with state.stream_ctx(state.copy_stream): + self._slice_to_memory_objs(group, device_buf) + copy_ms += (time.perf_counter() - t0) * 1000 + slot.free_event.record(state.copy_stream) + slot.free_event_valid = True + t0 = time.perf_counter() + state.copy_stream.synchronize() + sync_ms += (time.perf_counter() - t0) * 1000 + except Exception: + for slot in used_slots: + slot.free_event_valid = False + raise + finally: + self._release_slots_if_requested(used_slots) + self._record_transfer_stats( + chunks, + groups, + pack_ms=pack_ms, + copy_ms=copy_ms, + sync_ms=sync_ms, + transfer_ms=(time.perf_counter() - t_total0) * 1000, + ) return self._set_last_fastpath("chunk") - host_buf = self._ensure_host_tmp(state, total_nbytes) - if use_cuda: - if slot.free_event_valid: - state.pack_stream.wait_event(slot.free_event) - with state.stream_ctx(state.pack_stream): - self.codec.gpu_to_device_buffer( - device_buf, - all_block_ids, - stream=state.pack_stream, - ) - slot.ready_event.record(state.pack_stream) - state.copy_stream.wait_event(slot.ready_event) - with state.stream_ctx(state.copy_stream): - host_buf.copy_(device_buf, non_blocking=True) - slot.free_event.record(state.copy_stream) - slot.free_event_valid = True - state.copy_stream.synchronize() - else: - self.codec.gpu_to_device_buffer(device_buf, all_block_ids) - host_buf.copy_(device_buf, non_blocking=False) - - self.codec.split_request_buffer(host_buf, dst_tensors, chunk_block_counts) + pack_ms = 0.0 + copy_ms = 0.0 + sync_ms = 0.0 + t_total0 = time.perf_counter() + used_slots: list[_StagingSlot] = [] + for group in groups: + slot = self._next_slot(state) + self._remember_slot(used_slots, slot) + device_buf = self._ensure_slot(slot, group.nbytes) + host_buf = self._ensure_host_tmp(state, group.nbytes) + try: + if use_cuda: + if slot.free_event_valid: + state.pack_stream.wait_event(slot.free_event) + t0 = time.perf_counter() + with state.stream_ctx(state.pack_stream): + self.codec.gpu_to_chunk_major_device_buffer( + device_buf, + self._group_block_ids(group), + stream=state.pack_stream, + ) + pack_ms += (time.perf_counter() - t0) * 1000 + slot.ready_event.record(state.pack_stream) + state.copy_stream.wait_event(slot.ready_event) + t0 = time.perf_counter() + with state.stream_ctx(state.copy_stream): + host_buf.copy_(device_buf, non_blocking=True) + copy_ms += (time.perf_counter() - t0) * 1000 + slot.free_event.record(state.copy_stream) + slot.free_event_valid = True + t0 = time.perf_counter() + state.copy_stream.synchronize() + sync_ms += (time.perf_counter() - t0) * 1000 + else: + t0 = time.perf_counter() + self.codec.gpu_to_chunk_major_device_buffer( + device_buf, + self._group_block_ids(group), + ) + pack_ms += (time.perf_counter() - t0) * 1000 + t0 = time.perf_counter() + host_buf.copy_(device_buf, non_blocking=False) + copy_ms += (time.perf_counter() - t0) * 1000 + t0 = time.perf_counter() + self._slice_to_memory_objs(group, host_buf) + copy_ms += (time.perf_counter() - t0) * 1000 + except Exception: + slot.free_event_valid = False + raise + self._release_slots_if_requested(used_slots) + self._record_transfer_stats( + chunks, + groups, + pack_ms=pack_ms, + copy_ms=copy_ms, + sync_ms=sync_ms, + transfer_ms=(time.perf_counter() - t_total0) * 1000, + ) def batched_to_gpu( self, @@ -362,88 +702,129 @@ def batched_to_gpu( ends: list[int] | None = None, **kwargs, ) -> None: - """Load LMCache MemoryObjs back into ATOM KV blocks via double staging.""" + """Load LMCache MemoryObjs back into ATOM KV blocks via bounded staging.""" if memory_objs is None or starts is None or ends is None: raise ValueError("memory_objs, starts, and ends are required") if not (len(memory_objs) == len(starts) == len(ends)): raise ValueError("memory_objs, starts, and ends must have equal length") block_id_groups = self._ranges_to_block_ids(starts, ends, **kwargs) if not memory_objs: + self._set_last_transfer_stats() return state = self._thread_state() use_cuda = state.use_cuda - chunk_block_counts = [len(block_ids) for block_ids in block_id_groups] - all_block_ids = [ - block_id for block_ids in block_id_groups for block_id in block_ids - ] - total_nbytes = len(all_block_ids) * self.codec.bytes_per_block - if total_nbytes == 0: + chunks = self._iter_transfer_chunks(memory_objs, block_id_groups) + groups = self._iter_transfer_groups(chunks) + self._record_transfer_stats(chunks, groups) + if not chunks: return - slot = self._next_slot(state) - device_buf = self._ensure_slot(slot, total_nbytes) - src_tensors = [ - self._memory_tensor( - memory_obj, - block_count * self.codec.bytes_per_block, - ) - for memory_obj, block_count in zip( - memory_objs, - chunk_block_counts, - strict=True, - ) - ] - if self._can_use_fused_chunk_major(): self._set_last_fastpath("fused_chunk") - if slot.free_event_valid: - state.copy_stream.wait_event(slot.free_event) - with state.stream_ctx(state.copy_stream): - offset = 0 - for src, block_count in zip( - src_tensors, - chunk_block_counts, - strict=True, - ): - nbytes = block_count * self.codec.bytes_per_block - device_buf[offset : offset + nbytes].copy_( - src, - non_blocking=True, - ) - offset += nbytes - slot.ready_event.record(state.copy_stream) - state.pack_stream.wait_event(slot.ready_event) - with state.stream_ctx(state.pack_stream): - self.codec.chunk_major_device_buffer_to_gpu( - device_buf, - block_id_groups, - stream=state.pack_stream, - ) - slot.free_event.record(state.pack_stream) - slot.free_event_valid = True - state.pack_stream.synchronize() + used_slots: list[_StagingSlot] = [] + copy_ms = 0.0 + pack_ms = 0.0 + sync_ms = 0.0 + t_total0 = time.perf_counter() + try: + for group in groups: + slot = self._next_slot(state) + self._remember_slot(used_slots, slot) + device_buf = self._ensure_slot(slot, group.nbytes) + if slot.free_event_valid: + state.copy_stream.wait_event(slot.free_event) + t0 = time.perf_counter() + with state.stream_ctx(state.copy_stream): + self._memory_objs_to_slice(group, device_buf) + copy_ms += (time.perf_counter() - t0) * 1000 + slot.ready_event.record(state.copy_stream) + state.pack_stream.wait_event(slot.ready_event) + t0 = time.perf_counter() + with state.stream_ctx(state.pack_stream): + self.codec.chunk_major_device_buffer_to_gpu( + device_buf, + self._group_block_ids(group), + stream=state.pack_stream, + ) + pack_ms += (time.perf_counter() - t0) * 1000 + slot.free_event.record(state.pack_stream) + slot.free_event_valid = True + t0 = time.perf_counter() + state.pack_stream.synchronize() + sync_ms += (time.perf_counter() - t0) * 1000 + except Exception: + for slot in used_slots: + slot.free_event_valid = False + raise + finally: + self._release_slots_if_requested(used_slots) + self._record_transfer_stats( + chunks, + groups, + pack_ms=pack_ms, + copy_ms=copy_ms, + sync_ms=sync_ms, + transfer_ms=(time.perf_counter() - t_total0) * 1000, + ) return self._set_last_fastpath("chunk") - host_buf = self._ensure_host_tmp(state, total_nbytes) - self.codec.stitch_chunk_buffers(host_buf, src_tensors, chunk_block_counts) - if use_cuda: - if slot.free_event_valid: - state.copy_stream.wait_event(slot.free_event) - with state.stream_ctx(state.copy_stream): - device_buf.copy_(host_buf, non_blocking=True) - slot.ready_event.record(state.copy_stream) - state.pack_stream.wait_event(slot.ready_event) - with state.stream_ctx(state.pack_stream): - self.codec.device_buffer_to_gpu( - device_buf, - all_block_ids, - stream=state.pack_stream, - ) - slot.free_event.record(state.pack_stream) - slot.free_event_valid = True - state.pack_stream.synchronize() - else: - device_buf.copy_(host_buf, non_blocking=False) - self.codec.device_buffer_to_gpu(device_buf, all_block_ids) + copy_ms = 0.0 + pack_ms = 0.0 + sync_ms = 0.0 + t_total0 = time.perf_counter() + used_slots: list[_StagingSlot] = [] + for group in groups: + slot = self._next_slot(state) + self._remember_slot(used_slots, slot) + device_buf = self._ensure_slot(slot, group.nbytes) + host_buf = self._ensure_host_tmp(state, group.nbytes) + try: + t0 = time.perf_counter() + self._memory_objs_to_slice(group, host_buf) + copy_ms += (time.perf_counter() - t0) * 1000 + if use_cuda: + if slot.free_event_valid: + state.copy_stream.wait_event(slot.free_event) + t0 = time.perf_counter() + with state.stream_ctx(state.copy_stream): + device_buf.copy_(host_buf, non_blocking=True) + copy_ms += (time.perf_counter() - t0) * 1000 + slot.ready_event.record(state.copy_stream) + state.pack_stream.wait_event(slot.ready_event) + t0 = time.perf_counter() + with state.stream_ctx(state.pack_stream): + self.codec.chunk_major_device_buffer_to_gpu( + device_buf, + self._group_block_ids(group), + stream=state.pack_stream, + ) + pack_ms += (time.perf_counter() - t0) * 1000 + slot.free_event.record(state.pack_stream) + slot.free_event_valid = True + t0 = time.perf_counter() + state.pack_stream.synchronize() + sync_ms += (time.perf_counter() - t0) * 1000 + else: + t0 = time.perf_counter() + device_buf.copy_(host_buf, non_blocking=False) + copy_ms += (time.perf_counter() - t0) * 1000 + t0 = time.perf_counter() + self.codec.chunk_major_device_buffer_to_gpu( + device_buf, + self._group_block_ids(group), + ) + pack_ms += (time.perf_counter() - t0) * 1000 + except Exception: + slot.free_event_valid = False + raise + self._release_slots_if_requested(used_slots) + self._record_transfer_stats( + chunks, + groups, + pack_ms=pack_ms, + copy_ms=copy_ms, + sync_ms=sync_ms, + transfer_ms=(time.perf_counter() - t_total0) * 1000, + ) diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index b4c6fd28ab..6784733af9 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -360,6 +360,55 @@ def compute_block_bytes(self) -> int: ) return block_bytes + def compute_offload_staging_block_bytes(self) -> int: + """Per exposed KVCacheTensor block copied by ATOM offload. + + ``compute_block_bytes()`` is a scheduler-block KV pool estimate. + The offload codec copies slices exposed by ``build_kv_cache_tensor``; + for the AITER MHA layout those slices are physical KV blocks. + """ + from aiter import dtypes + + runner = self.model_runner + config = runner.config + hf_config = config.hf_config + kv_dtype_size = dtypes.d_dtypes[config.kv_cache_dtype].itemsize + physical_block_size = runner.physical_block_size + + def per_layer_bytes(num_kv_heads: int) -> int: + block_bytes = ( + 2 + * num_kv_heads + * hf_config.head_dim + * physical_block_size + * kv_dtype_size + ) + if config.kv_cache_dtype == "fp8": + block_bytes += 2 * num_kv_heads * physical_block_size * 4 + return block_bytes + + if runner.is_mimo_v2(): + pattern = hf_config.hybrid_layer_pattern + num_swa_layers = sum( + 1 for i in range(hf_config.num_hidden_layers) if pattern[i] == 1 + ) + num_full_layers = hf_config.num_hidden_layers - num_swa_layers + num_draft_layers = ( + runner._get_total_num_layers() - hf_config.num_hidden_layers + ) + num_swa_layers += num_draft_layers + _swa_raw = getattr(hf_config, "swa_num_key_value_heads", 0) + swa_kv_heads = ( + _swa_raw // runner.world_size + if _swa_raw >= runner.world_size + else (1 if _swa_raw else 0) + ) + return num_full_layers * per_layer_bytes( + runner._get_num_kv_heads() + ) + num_swa_layers * per_layer_bytes(swa_kv_heads) + + return hf_config.num_hidden_layers * per_layer_bytes(runner._get_num_kv_heads()) + def allocate_kv_cache_tensors( self, num_kv_heads: int, num_draft_layers: int ) -> dict: diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 05735599dd..811af65b51 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -585,6 +585,14 @@ def compute_block_bytes(self) -> int: ) return block_bytes + def compute_offload_staging_block_bytes(self) -> int: + """Per exposed MLA KVCacheTensor block copied by ATOM offload.""" + runner = self.model_runner + config = runner.config + total_num_layers = runner._get_total_num_layers() + kv_dtype_size = dtypes.d_dtypes[config.kv_cache_dtype].itemsize + return total_num_layers * 576 * kv_dtype_size + def allocate_kv_cache_tensors( self, num_kv_heads: int, num_draft_layers: int ) -> dict: diff --git a/atom/model_ops/attentions/backends.py b/atom/model_ops/attentions/backends.py index 36ac0daf32..7c87abb184 100644 --- a/atom/model_ops/attentions/backends.py +++ b/atom/model_ops/attentions/backends.py @@ -167,6 +167,16 @@ def compute_block_bytes(self) -> int: """ return 0 + def compute_offload_staging_block_bytes(self) -> int: + """Bytes copied by ATOM KV offload for one block id. + + This must match ``ATOMKVByteCodec.bytes_per_block`` without requiring + the actual KV tensors to exist yet. The default mirrors the KV pool + block estimate; backends whose exposed ``KVCacheTensor`` block geometry + differs from the scheduler block geometry should override it. + """ + return self.compute_block_bytes() + def allocate_kv_cache_tensors( self, num_kv_heads: int, num_draft_layers: int ) -> dict[str, Any]: diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index 43346b7b5c..f2e69713f8 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -281,7 +281,7 @@ def test_lmcache_connector_maps_token_ranges_to_block_ids(monkeypatch): ) } codec = ATOMKVByteCodec(kv_caches) - connector = ATOMLMCacheGPUConnector(codec, block_size=4) + connector = ATOMLMCacheGPUConnector(codec, block_size=4, chunk_size=8) memory_obj = SimpleNamespace( tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) ) @@ -318,6 +318,7 @@ def test_lmcache_connector_fused_chunk_fastpath_uses_chunk_major(monkeypatch): pytest.skip("real torch is unavailable") monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") + monkeypatch.setenv("OFFLOAD_GPU_STAGING_CHUNKS", "2") original = { "l0": SimpleNamespace( k_cache=torch.arange(6 * 2, dtype=torch.uint8).reshape(6, 2), @@ -335,24 +336,39 @@ def test_lmcache_connector_fused_chunk_fastpath_uses_chunk_major(monkeypatch): ) } codec = ATOMKVByteCodec(kv_caches) - connector = ATOMLMCacheGPUConnector(codec, block_size=4) + connector = ATOMLMCacheGPUConnector(codec, block_size=4, chunk_size=8) monkeypatch.setattr(connector, "_can_use_fused_chunk_major", lambda: True) orig_pack = codec.gpu_to_chunk_major_device_buffer orig_unpack = codec.chunk_major_device_buffer_to_gpu + + pack_groups = [] + unpack_groups = [] + slot_requests = [] + monkeypatch.setattr( codec, "gpu_to_chunk_major_device_buffer", - lambda device_buf, block_id_groups, stream=None: orig_pack( - device_buf, block_id_groups, stream=None - ), + lambda device_buf, block_id_groups, stream=None: ( + pack_groups.append([list(group) for group in block_id_groups]), + orig_pack(device_buf, block_id_groups, stream=None), + )[-1], ) monkeypatch.setattr( codec, "chunk_major_device_buffer_to_gpu", - lambda device_buf, block_id_groups, stream=None: orig_unpack( - device_buf, block_id_groups, stream=None - ), + lambda device_buf, block_id_groups, stream=None: ( + unpack_groups.append([list(group) for group in block_id_groups]), + orig_unpack(device_buf, block_id_groups, stream=None), + )[-1], ) + orig_ensure_slot = connector._ensure_slot + + def _ensure_slot(slot, nbytes): + device_buf = orig_ensure_slot(slot, nbytes) + slot_requests.append((nbytes, int(slot.tensor.numel()))) + return device_buf + + monkeypatch.setattr(connector, "_ensure_slot", _ensure_slot) class _FakeEvent: def record(self, stream) -> None: @@ -420,6 +436,20 @@ def stream_ctx(self, stream): ] ) assert connector.last_fastpath() == "fused_chunk" + transfer_stats = connector.last_transfer_stats() + assert transfer_stats["chunks"] == 2 + assert transfer_stats["groups"] == 1 + assert transfer_stats["max_chunk_bytes"] == 2 * codec.bytes_per_block + assert transfer_stats["max_group_bytes"] == 3 * codec.bytes_per_block + assert transfer_stats["total_bytes"] == 3 * codec.bytes_per_block + assert transfer_stats["gpu_staging_chunk_bytes"] == 2 * codec.bytes_per_block + assert transfer_stats["gpu_staging_group_chunks"] == 2 + assert transfer_stats["gpu_staging_capacity_bytes"] == 4 * codec.bytes_per_block + assert transfer_stats["gpu_staging_slots"] == 1 + assert transfer_stats["transfer_ms"] >= 0 + assert pack_groups == [[[1, 2], [3]]] + assert all(nbytes <= 4 * codec.bytes_per_block for nbytes, _ in slot_requests) + assert all(capacity == 4 * codec.bytes_per_block for _, capacity in slot_requests) assert torch.equal(memory_objs[0].tensor, expected0) assert torch.equal(memory_objs[1].tensor, expected1) @@ -433,6 +463,7 @@ def stream_ctx(self, stream): ) assert connector.last_fastpath() == "fused_chunk" + assert unpack_groups == [[[1, 2], [3]]] for bid in [1, 2, 3]: assert torch.equal(kv_caches["l0"].k_cache[bid], original["l0"].k_cache[bid]) assert torch.equal(kv_caches["l0"].v_cache[bid], original["l0"].v_cache[bid]) @@ -440,6 +471,214 @@ def stream_ctx(self, stream): assert torch.count_nonzero(kv_caches["l0"].v_cache[0]) == 0 +def test_lmcache_connector_fallback_staging_is_chunk_bounded(monkeypatch): + import torch + + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") + monkeypatch.setenv("OFFLOAD_GPU_STAGING_CHUNKS", "1") + original = { + "l0": SimpleNamespace( + k_cache=torch.arange(8 * 2, dtype=torch.uint8).reshape(8, 2), + v_cache=(torch.arange(8 * 3, dtype=torch.uint8).reshape(8, 3) + 51), + k_scale=None, + v_scale=None, + ) + } + kv_caches = { + "l0": SimpleNamespace( + k_cache=original["l0"].k_cache.clone(), + v_cache=original["l0"].v_cache.clone(), + k_scale=None, + v_scale=None, + ) + } + codec = ATOMKVByteCodec(kv_caches) + connector = ATOMLMCacheGPUConnector(codec, block_size=4, chunk_size=8) + cap = 2 * codec.bytes_per_block + slot_requests = [] + host_requests = [] + orig_ensure_slot = connector._ensure_slot + orig_ensure_host_tmp = connector._ensure_host_tmp + + def _ensure_slot(slot, nbytes): + device_buf = orig_ensure_slot(slot, nbytes) + slot_requests.append((nbytes, int(slot.tensor.numel()))) + return device_buf + + def _ensure_host_tmp(state, nbytes): + host_buf = orig_ensure_host_tmp(state, nbytes) + host_requests.append((nbytes, int(state.host_tmp.numel()))) + return host_buf + + monkeypatch.setattr(connector, "_ensure_slot", _ensure_slot) + monkeypatch.setattr(connector, "_ensure_host_tmp", _ensure_host_tmp) + memory_objs = [ + SimpleNamespace( + tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) + ), + SimpleNamespace( + tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) + ), + SimpleNamespace( + tensor=torch.empty(1 * codec.bytes_per_block, dtype=torch.uint8) + ), + ] + + connector.batched_from_gpu( + memory_objs, + [0, 8, 16], + [8, 16, 20], + block_ids=list(range(8)), + ) + + assert connector.last_fastpath() == "chunk" + assert connector.last_transfer_stats()["chunks"] == 3 + assert connector.last_transfer_stats()["max_chunk_bytes"] == cap + assert all(nbytes <= cap for nbytes, _ in slot_requests) + assert all(capacity == cap for _, capacity in slot_requests) + assert all(nbytes <= cap for nbytes, _ in host_requests) + assert all(capacity == cap for _, capacity in host_requests) + + kv_caches["l0"].k_cache.zero_() + kv_caches["l0"].v_cache.zero_() + connector.batched_to_gpu( + memory_objs, + [0, 8, 16], + [8, 16, 20], + block_ids=list(range(8)), + ) + + for bid in range(5): + assert torch.equal(kv_caches["l0"].k_cache[bid], original["l0"].k_cache[bid]) + assert torch.equal(kv_caches["l0"].v_cache[bid], original["l0"].v_cache[bid]) + assert torch.count_nonzero(kv_caches["l0"].k_cache[5]) == 0 + assert torch.count_nonzero(kv_caches["l0"].v_cache[5]) == 0 + + +def test_lmcache_connector_release_covers_fallback_chunks(monkeypatch): + import torch + + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") + monkeypatch.setenv("OFFLOAD_GPU_STAGING_CHUNKS", "1") + monkeypatch.setenv("OFFLOAD_RELEASE_GPU_STAGING_AFTER_TRANSFER", "1") + kv_caches = { + "l0": SimpleNamespace( + k_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), + v_cache=(torch.arange(4 * 3, dtype=torch.uint8).reshape(4, 3) + 51), + k_scale=None, + v_scale=None, + ) + } + codec = ATOMKVByteCodec(kv_caches) + connector = ATOMLMCacheGPUConnector(codec, block_size=4, chunk_size=8) + memory_objs = [ + SimpleNamespace( + tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) + ), + SimpleNamespace( + tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) + ), + ] + + connector.batched_from_gpu( + memory_objs, + [0, 8], + [8, 16], + block_ids=list(range(4)), + ) + + state = connector._thread_state() + assert state.host_tmp is not None + assert int(state.host_tmp.numel()) == 2 * codec.bytes_per_block + assert all(slot.tensor is None for slot in state.slots) + + +def test_lmcache_connector_rejects_oversized_memory_obj(monkeypatch): + import torch + + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") + kv_caches = { + "l0": SimpleNamespace( + k_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), + v_cache=(torch.arange(4 * 3, dtype=torch.uint8).reshape(4, 3) + 51), + k_scale=None, + v_scale=None, + ) + } + codec = ATOMKVByteCodec(kv_caches) + connector = ATOMLMCacheGPUConnector(codec, block_size=4, chunk_size=4) + memory_obj = SimpleNamespace( + tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) + ) + + with pytest.raises(ValueError, match="single MemoryObj exceeds"): + connector.batched_from_gpu( + [memory_obj], + [0], + [8], + block_ids=list(range(4)), + ) + + +def test_lmcache_connector_respects_staging_slot_env(monkeypatch): + import torch + + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") + monkeypatch.setenv("OFFLOAD_GPU_STAGING_SLOTS", "2") + monkeypatch.setenv("OFFLOAD_GPU_STAGING_CHUNKS", "3") + kv_caches = { + "l0": SimpleNamespace( + k_cache=torch.arange(2 * 2, dtype=torch.uint8).reshape(2, 2), + v_cache=torch.arange(2 * 3, dtype=torch.uint8).reshape(2, 3), + k_scale=None, + v_scale=None, + ) + } + codec = ATOMKVByteCodec(kv_caches) + connector = ATOMLMCacheGPUConnector(codec, block_size=4, chunk_size=4) + + assert connector.staging_slots == 2 + assert connector.gpu_staging_group_chunks == 3 + assert connector.gpu_staging_capacity_bytes == 3 * connector.gpu_staging_chunk_bytes + assert len(connector._thread_state().slots) == 2 + + +def test_lmcache_connector_default_staging_group_chunks_is_two(monkeypatch): + import torch + + if not hasattr(torch, "arange"): + pytest.skip("real torch is unavailable") + + monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") + monkeypatch.delenv("OFFLOAD_GPU_STAGING_CHUNKS", raising=False) + monkeypatch.delenv("OFFLOAD_GPU_STAGING_MAX_BYTES", raising=False) + kv_caches = { + "l0": SimpleNamespace( + k_cache=torch.arange(2 * 2, dtype=torch.uint8).reshape(2, 2), + v_cache=torch.arange(2 * 3, dtype=torch.uint8).reshape(2, 3), + k_scale=None, + v_scale=None, + ) + } + codec = ATOMKVByteCodec(kv_caches) + connector = ATOMLMCacheGPUConnector(codec, block_size=4, chunk_size=4) + + assert connector.gpu_staging_group_chunks == 2 + assert connector.gpu_staging_capacity_bytes == 2 * connector.gpu_staging_chunk_bytes + + def test_codec_chunk_major_device_buffer_layout(monkeypatch): import torch From 5084d247ce04639ff877a353e89700c0218ff52d Mon Sep 17 00:00:00 2001 From: yihonglie Date: Wed, 3 Jun 2026 22:45:36 -0500 Subject: [PATCH 15/27] Remove obsolete offload staging fallback code --- atom/kv_transfer/offload/connector.py | 35 +-- atom/kv_transfer/offload/gpu_connector.py | 311 +-------------------- atom/kv_transfer/offload/native_stitch.cpp | 150 ---------- atom/kv_transfer/offload/native_stitch.py | 56 ---- tests/test_lmcache_offload_connector.py | 226 +-------------- 5 files changed, 13 insertions(+), 765 deletions(-) delete mode 100644 atom/kv_transfer/offload/native_stitch.cpp delete mode 100644 atom/kv_transfer/offload/native_stitch.py diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index cb322296a4..9a0387f696 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -87,15 +87,12 @@ def __init__(self, config) -> None: self._save_executor = ThreadPoolExecutor( max_workers=n_save_workers, thread_name_prefix="lmc-offload-save" ) - self._tls = threading.local() # per-thread copy stream self._lock = threading.Lock() self._done_load: set[ReqId] = set() self._done_save: set[ReqId] = set() self._failed_load: set[ReqId] = set() self._engine = None - self._sm = None - self._tdb = None self._codec: ATOMKVByteCodec | None = None self._lookup_server = None @@ -140,8 +137,6 @@ def register_kv_caches(self, kv_caches: dict, transfer_tensors=None) -> None: # opaque uint8 object, so keep a supported tensor MemoryFormat. self._engine.fmt = MemoryFormat.KV_2LTD self._engine.post_init() - self._sm = self._engine.storage_manager - self._tdb = self._engine.token_database # DEBUG: wrap engine.lookup to capture EVERY call (incl. the ones the ZMQ # lookup_server makes on behalf of the scheduler) — args + result. @@ -182,7 +177,7 @@ def _logged_lookup(*a, **k): rank, self._codec.bytes_per_block, self.chunk_size, - self._codec.layout, + "chunk_major", gpu_connector.staging_slots, gpu_connector.gpu_staging_chunk_bytes, gpu_connector.gpu_staging_group_chunks, @@ -243,34 +238,6 @@ def _lookup_unpin(self, req_id) -> None: except Exception: pass - def _copy_device(self) -> torch.device | None: - codec = getattr(self, "_codec", None) - device = getattr(codec, "device", None) - if device is None: - return None - device = torch.device(device) - if device.type != "cuda": - return None - return device - - def _stream(self) -> torch.cuda.Stream: - """A CUDA stream owned by the calling copy-daemon thread and device.""" - device = self._copy_device() - key = str(device) if device is not None else "default" - streams = getattr(self._tls, "streams", None) - if streams is None: - streams = {} - self._tls.streams = streams - s = streams.get(key) - if s is None: - if device is None: - s = torch.cuda.Stream() - else: - with torch.cuda.device(device): - s = torch.cuda.Stream() - streams[key] = s - return s - def _profile_enabled(self) -> bool: return os.environ.get("OFFLOAD_PROFILE", "1").lower() not in ( "0", diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py index 8ccad58ee8..8cb50400e2 100644 --- a/atom/kv_transfer/offload/gpu_connector.py +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -"""AITER-layout-aware byte codec between ATOM's paged GPU KV cache and a flat -pinned host buffer (an LMCache ``MemoryObj``'s ``uint8`` tensor). +"""AITER-layout-aware byte codec between ATOM's paged GPU KV cache and flat +``uint8`` staging buffers. Why a byte codec instead of an LMCache ``GPUConnectorInterface`` subclass: LMCache's ``engine.store/retrieve`` GPU path only emits token-major formats @@ -19,11 +19,11 @@ A whole *block* of any per-layer cache tensor (``t[block_id]``) is contiguous, so a block's KV is a set of contiguous byte slices: per layer K, V, and (fp8) k_scale, -v_scale. The flat per-block layout in the host buffer is:: +v_scale. The canonical staging layout for one chunk is segment-major:: - [ L0.K | L0.V | L0.kS | L0.vS | L1.K | L1.V | ... ] (only present segments) + [ all L0.K blocks | all L0.V blocks | all L0.kS blocks | ... ] -which is self-consistent for store and load (we never reinterpret it). +and batched transfers concatenate those per-chunk buffers for LMCache MemoryObjs. """ from __future__ import annotations @@ -31,7 +31,6 @@ import logging import operator import os -import threading import torch @@ -78,34 +77,8 @@ def __init__(self, kv_caches: dict) -> None: self._seg_block_bytes: list[int] = [ int(t[0].numel()) * t.element_size() for t in self._segments ] - # Byte offset of each segment within one block's flat record. - self._seg_off: list[int] = [] - acc = 0 - for nb in self._seg_block_bytes: - self._seg_off.append(acc) - acc += nb - self.bytes_per_block: int = acc - self.layout = os.environ.get("OFFLOAD_CODEC_LAYOUT", "block").lower() - if self.layout not in ("block", "segment", "segment_indexed"): - self.layout = "block" - self._tls = threading.local() - self._native_stitch = None - self._native_split = None + self.bytes_per_block: int = sum(self._seg_block_bytes) self._native_kv_staging = None - if self.layout == "segment_indexed" and os.environ.get( - "OFFLOAD_NATIVE_STITCH", "0" - ).lower() not in ("0", "false", "no", "off"): - try: - from atom.kv_transfer.offload import native_stitch - - native_stitch.load_extension() - self._native_stitch = native_stitch.stitch_chunk_buffers - self._native_split = native_stitch.split_request_buffer - except Exception: - logger.warning( - "ATOMKVByteCodec: native stitch unavailable; using torch stitch", - exc_info=True, - ) if self._device.type == "cuda" and os.environ.get( "OFFLOAD_NATIVE_KV_STAGING", "0" ).lower() not in ("0", "false", "no", "off"): @@ -121,10 +94,6 @@ def __init__(self, kv_caches: dict) -> None: exc_info=True, ) - @property - def segments_per_block(self) -> int: - return len(self._segments) - @property def device(self) -> torch.device: return self._device @@ -133,61 +102,7 @@ def device(self) -> torch.device: def has_native_chunk_major_staging(self) -> bool: return self._native_kv_staging is not None - def copy_calls_for_blocks(self, nblocks: int) -> int: - return int(nblocks) * len(self._segments) - - def copy_calls_for_block_ids(self, block_ids: list[int]) -> int: - if self.layout == "block": - return self.copy_calls_for_blocks(len(block_ids)) - if self.layout == "segment_indexed": - return len(self._segments) * 2 - return len(self._segments) * len(list(self._contiguous_runs(block_ids))) - # -- helpers ---------------------------------------------------------- - @staticmethod - def _block_bytes_view(seg: torch.Tensor, block_id: int) -> torch.Tensor: - """Flat ``uint8`` view of one contiguous block slice (no copy).""" - blk = seg[block_id] - if not blk.is_contiguous(): - # Block slices of the paged cache are contiguous in practice; guard - # anyway. A non-contiguous block would break in-place H2D, so fail loud. - raise RuntimeError("ATOMKVByteCodec: block slice not contiguous") - return blk.reshape(-1).view(torch.uint8) - - @staticmethod - def _blocks_bytes_view( - seg: torch.Tensor, - block_id: int, - nblocks: int, - ) -> torch.Tensor: - """Flat ``uint8`` view of a contiguous block range (no copy).""" - blk = seg[block_id : block_id + nblocks] - if not blk.is_contiguous(): - raise RuntimeError("ATOMKVByteCodec: block range not contiguous") - return blk.reshape(-1).view(torch.uint8) - - @staticmethod - def _contiguous_runs(block_ids: list[int]): - """Yield ``(logical_start, physical_start, run_len)`` for increasing - physical block-id runs in logical order.""" - if not block_ids: - return - logical_start = 0 - physical_start = block_ids[0] - prev = block_ids[0] - run_len = 1 - for logical_idx, bid in enumerate(block_ids[1:], start=1): - if bid == prev + 1: - prev = bid - run_len += 1 - continue - yield logical_start, physical_start, run_len - logical_start = logical_idx - physical_start = bid - prev = bid - run_len = 1 - yield logical_start, physical_start, run_len - def _segment_bases(self, nblocks: int) -> list[int]: bases = [] acc = 0 @@ -231,17 +146,6 @@ def _normalize_block_id_groups( raise ValueError("ATOMKVByteCodec: duplicate block ids are not supported") return groups, flat, [len(block_ids) for block_ids in groups] - def _validate_host_buf(self, host_buf: torch.Tensor, nblocks: int) -> None: - if host_buf.dtype != torch.uint8: - raise TypeError("ATOMKVByteCodec: host_buf must be a uint8 tensor") - required = int(nblocks) * self.bytes_per_block - if int(host_buf.numel()) < required: - raise ValueError( - "ATOMKVByteCodec: host_buf is too small " - f"for {nblocks} blocks; need {required} bytes, " - f"got {int(host_buf.numel())}" - ) - def _validate_device_buf(self, device_buf: torch.Tensor, nblocks: int) -> None: if device_buf.dtype != torch.uint8: raise TypeError("ATOMKVByteCodec: device_buf must be a uint8 tensor") @@ -258,100 +162,6 @@ def _validate_device_buf(self, device_buf: torch.Tensor, nblocks: int) -> None: f"got {int(device_buf.numel())}" ) - def stitch_chunk_buffers( - self, - dst: torch.Tensor, - chunk_buffers: list[torch.Tensor], - chunk_block_counts: list[int], - ) -> None: - """CPU-side stitch from per-LMCache-chunk segment-major buffers into one - request-level segment-major buffer. - - Each stored chunk is laid out as ``[seg0 chunk_blocks | seg1 ...]``. A - single request-level indexed H2D scatter expects - ``[seg0 all_blocks | seg1 all_blocks | ...]``. - """ - if self._native_stitch is not None: - self._native_stitch( - dst, - chunk_buffers, - chunk_block_counts, - self._seg_block_bytes, - ) - return - total_blocks = sum(chunk_block_counts) - dst_bases = self._segment_bases(total_blocks) - src_bases_by_chunk = [ - self._segment_bases(nblocks) for nblocks in chunk_block_counts - ] - for seg_idx, (dst_base, nb) in enumerate(zip(dst_bases, self._seg_block_bytes)): - parts = [ - src[bases[seg_idx] : bases[seg_idx] + nblocks * nb] - for src, bases, nblocks in zip( - chunk_buffers, src_bases_by_chunk, chunk_block_counts - ) - ] - torch.cat( - parts, - out=dst[dst_base : dst_base + total_blocks * nb], - ) - - def split_request_buffer( - self, - src: torch.Tensor, - chunk_buffers: list[torch.Tensor], - chunk_block_counts: list[int], - ) -> None: - """CPU-side inverse of :meth:`stitch_chunk_buffers`. - - ``src`` is one request-level segment-major buffer - ``[seg0 all_blocks | seg1 all_blocks | ...]``. Each destination chunk - receives its own segment-major slice - ``[seg0 chunk_blocks | seg1 chunk_blocks | ...]`` for LMCache storage. - """ - if self._native_split is not None: - self._native_split( - src, - chunk_buffers, - chunk_block_counts, - self._seg_block_bytes, - ) - return - total_blocks = sum(chunk_block_counts) - src_bases = self._segment_bases(total_blocks) - dst_bases_by_chunk = [ - self._segment_bases(nblocks) for nblocks in chunk_block_counts - ] - for seg_idx, (src_base, nb) in enumerate(zip(src_bases, self._seg_block_bytes)): - logical_block_start = 0 - for dst, bases, nblocks in zip( - chunk_buffers, dst_bases_by_chunk, chunk_block_counts - ): - nbytes = nblocks * nb - if nbytes: - dst[bases[seg_idx] : bases[seg_idx] + nbytes].copy_( - src[ - src_base - + logical_block_start * nb : src_base - + logical_block_start * nb - + nbytes - ] - ) - logical_block_start += nblocks - - def _tmp_bytes(self, seg: torch.Tensor, nblocks: int) -> torch.Tensor: - elems = int(seg[0].numel()) * seg.element_size() - key = (str(seg.device), "uint8", elems, int(nblocks)) - cache = getattr(self._tls, "tmp", None) - if cache is None: - cache = {} - self._tls.tmp = cache - tmp = cache.get(key) - if tmp is None: - tmp = torch.empty((nblocks, elems), dtype=torch.uint8, device=seg.device) - cache[key] = tmp - return tmp - @staticmethod def _segment_bytes_matrix(seg: torch.Tensor) -> torch.Tensor: if not seg.is_contiguous(): @@ -359,115 +169,6 @@ def _segment_bytes_matrix(seg: torch.Tensor) -> torch.Tensor: return seg.reshape(seg.shape[0], -1).view(torch.uint8) # -- public API ------------------------------------------------------- - def gpu_to_host( - self, - host_buf: torch.Tensor, - block_ids: list[int], - stream: torch.cuda.Stream | None = None, - ) -> None: - """D2H: gather ``block_ids`` from the paged GPU cache into the flat - pinned ``host_buf`` (uint8, length == len(block_ids) * bytes_per_block).""" - block_ids = self._normalize_block_ids(block_ids) - self._validate_host_buf(host_buf, len(block_ids)) - if not block_ids: - return - with self._device_ctx(): - stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() - with stream_ctx: - if self.layout == "segment_indexed": - idx = torch.tensor(block_ids, dtype=torch.long, device=self._device) - bases = self._segment_bases(len(block_ids)) - for seg, base, nb in zip( - self._segments, bases, self._seg_block_bytes - ): - mat = self._segment_bytes_matrix(seg) - tmp = self._tmp_bytes(seg, len(block_ids)) - torch.index_select(mat, 0, idx, out=tmp) - host_buf[base : base + len(block_ids) * nb].copy_( - tmp.reshape(-1), non_blocking=True - ) - return - - if self.layout == "segment": - bases = self._segment_bases(len(block_ids)) - runs = list(self._contiguous_runs(block_ids)) - for seg, base, nb in zip( - self._segments, bases, self._seg_block_bytes - ): - for logical_start, physical_start, run_len in runs: - src = self._blocks_bytes_view(seg, physical_start, run_len) - dst = base + logical_start * nb - host_buf[dst : dst + run_len * nb].copy_( - src, non_blocking=True - ) - return - - for i, bid in enumerate(block_ids): - base = i * self.bytes_per_block - for seg, off, nb in zip( - self._segments, self._seg_off, self._seg_block_bytes - ): - src = self._block_bytes_view(seg, bid) - host_buf[base + off : base + off + nb].copy_( - src, non_blocking=True - ) - - def host_to_gpu( - self, - host_buf: torch.Tensor, - block_ids: list[int], - stream: torch.cuda.Stream | None = None, - ) -> None: - """H2D: scatter the flat pinned ``host_buf`` back into the paged GPU - cache at ``block_ids`` (in-place into the real KV tensors).""" - block_ids = self._normalize_block_ids(block_ids) - self._validate_host_buf(host_buf, len(block_ids)) - if not block_ids: - return - with self._device_ctx(): - stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() - with stream_ctx: - if self.layout == "segment_indexed": - idx = torch.tensor(block_ids, dtype=torch.long, device=self._device) - bases = self._segment_bases(len(block_ids)) - for seg, base, nb in zip( - self._segments, bases, self._seg_block_bytes - ): - mat = self._segment_bytes_matrix(seg) - tmp = self._tmp_bytes(seg, len(block_ids)) - tmp.copy_( - host_buf[base : base + len(block_ids) * nb].reshape_as(tmp), - non_blocking=True, - ) - mat.index_copy_(0, idx, tmp) - return - - if self.layout == "segment": - bases = self._segment_bases(len(block_ids)) - runs = list(self._contiguous_runs(block_ids)) - for seg, base, nb in zip( - self._segments, bases, self._seg_block_bytes - ): - for logical_start, physical_start, run_len in runs: - dst = self._blocks_bytes_view(seg, physical_start, run_len) - src = base + logical_start * nb - dst.copy_( - host_buf[src : src + run_len * nb], - non_blocking=True, - ) - return - - for i, bid in enumerate(block_ids): - base = i * self.bytes_per_block - for seg, off, nb in zip( - self._segments, self._seg_off, self._seg_block_bytes - ): - dst = self._block_bytes_view(seg, bid) - dst.copy_( - host_buf[base + off : base + off + nb], - non_blocking=True, - ) - def gpu_to_device_buffer( self, device_buf: torch.Tensor, diff --git a/atom/kv_transfer/offload/native_stitch.cpp b/atom/kv_transfer/offload/native_stitch.cpp deleted file mode 100644 index 6e76ca425c..0000000000 --- a/atom/kv_transfer/offload/native_stitch.cpp +++ /dev/null @@ -1,150 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include - -#include -#include -#include - -void stitch_chunk_buffers( - torch::Tensor dst, - std::vector chunk_buffers, - std::vector chunk_block_counts, - std::vector seg_block_bytes) { - TORCH_CHECK(dst.device().is_cpu(), "dst must be a CPU tensor"); - TORCH_CHECK(dst.dtype() == torch::kUInt8, "dst must be uint8"); - TORCH_CHECK(dst.is_contiguous(), "dst must be contiguous"); - TORCH_CHECK( - chunk_buffers.size() == chunk_block_counts.size(), - "chunk_buffers and chunk_block_counts size mismatch"); - - const int64_t nchunks = static_cast(chunk_buffers.size()); - const int64_t nsegs = static_cast(seg_block_bytes.size()); - int64_t total_blocks = 0; - for (int64_t nblocks : chunk_block_counts) { - TORCH_CHECK(nblocks >= 0, "chunk block count must be non-negative"); - total_blocks += nblocks; - } - - std::vector dst_bases(nsegs); - int64_t acc = 0; - for (int64_t seg = 0; seg < nsegs; ++seg) { - const int64_t nb = seg_block_bytes[seg]; - TORCH_CHECK(nb >= 0, "segment byte count must be non-negative"); - dst_bases[seg] = acc; - acc += nb * total_blocks; - } - TORCH_CHECK(dst.numel() >= acc, "dst is smaller than stitched output"); - - std::vector src_ptrs(nchunks); - std::vector src_offsets(nchunks * nsegs); - for (int64_t c = 0; c < nchunks; ++c) { - const auto& src = chunk_buffers[c]; - TORCH_CHECK(src.device().is_cpu(), "chunk buffer must be a CPU tensor"); - TORCH_CHECK(src.dtype() == torch::kUInt8, "chunk buffer must be uint8"); - TORCH_CHECK(src.is_contiguous(), "chunk buffer must be contiguous"); - src_ptrs[c] = src.data_ptr(); - - int64_t src_acc = 0; - const int64_t nblocks = chunk_block_counts[c]; - for (int64_t seg = 0; seg < nsegs; ++seg) { - src_offsets[c * nsegs + seg] = src_acc; - src_acc += seg_block_bytes[seg] * nblocks; - } - TORCH_CHECK(src.numel() >= src_acc, "chunk buffer is smaller than expected"); - } - - auto* dst_ptr = dst.data_ptr(); - at::parallel_for(0, nsegs, 1, [&](int64_t begin, int64_t end) { - for (int64_t seg = begin; seg < end; ++seg) { - const int64_t nb = seg_block_bytes[seg]; - int64_t logical_block_start = 0; - for (int64_t c = 0; c < nchunks; ++c) { - const int64_t nblocks = chunk_block_counts[c]; - const int64_t nbytes = nblocks * nb; - if (nbytes > 0) { - std::memcpy( - dst_ptr + dst_bases[seg] + logical_block_start * nb, - src_ptrs[c] + src_offsets[c * nsegs + seg], - static_cast(nbytes)); - } - logical_block_start += nblocks; - } - } - }); -} - -void split_request_buffer( - torch::Tensor src, - std::vector chunk_buffers, - std::vector chunk_block_counts, - std::vector seg_block_bytes) { - TORCH_CHECK(src.device().is_cpu(), "src must be a CPU tensor"); - TORCH_CHECK(src.dtype() == torch::kUInt8, "src must be uint8"); - TORCH_CHECK(src.is_contiguous(), "src must be contiguous"); - TORCH_CHECK( - chunk_buffers.size() == chunk_block_counts.size(), - "chunk_buffers and chunk_block_counts size mismatch"); - - const int64_t nchunks = static_cast(chunk_buffers.size()); - const int64_t nsegs = static_cast(seg_block_bytes.size()); - int64_t total_blocks = 0; - for (int64_t nblocks : chunk_block_counts) { - TORCH_CHECK(nblocks >= 0, "chunk block count must be non-negative"); - total_blocks += nblocks; - } - - std::vector src_bases(nsegs); - int64_t acc = 0; - for (int64_t seg = 0; seg < nsegs; ++seg) { - const int64_t nb = seg_block_bytes[seg]; - TORCH_CHECK(nb >= 0, "segment byte count must be non-negative"); - src_bases[seg] = acc; - acc += nb * total_blocks; - } - TORCH_CHECK(src.numel() >= acc, "src is smaller than split input"); - - std::vector dst_ptrs(nchunks); - std::vector dst_offsets(nchunks * nsegs); - for (int64_t c = 0; c < nchunks; ++c) { - auto& dst = chunk_buffers[c]; - TORCH_CHECK(dst.device().is_cpu(), "chunk buffer must be a CPU tensor"); - TORCH_CHECK(dst.dtype() == torch::kUInt8, "chunk buffer must be uint8"); - TORCH_CHECK(dst.is_contiguous(), "chunk buffer must be contiguous"); - dst_ptrs[c] = dst.data_ptr(); - - int64_t dst_acc = 0; - const int64_t nblocks = chunk_block_counts[c]; - for (int64_t seg = 0; seg < nsegs; ++seg) { - dst_offsets[c * nsegs + seg] = dst_acc; - dst_acc += seg_block_bytes[seg] * nblocks; - } - TORCH_CHECK(dst.numel() >= dst_acc, "chunk buffer is smaller than expected"); - } - - const auto* src_ptr = src.data_ptr(); - at::parallel_for(0, nsegs, 1, [&](int64_t begin, int64_t end) { - for (int64_t seg = begin; seg < end; ++seg) { - const int64_t nb = seg_block_bytes[seg]; - int64_t logical_block_start = 0; - for (int64_t c = 0; c < nchunks; ++c) { - const int64_t nblocks = chunk_block_counts[c]; - const int64_t nbytes = nblocks * nb; - if (nbytes > 0) { - std::memcpy( - dst_ptrs[c] + dst_offsets[c * nsegs + seg], - src_ptr + src_bases[seg] + logical_block_start * nb, - static_cast(nbytes)); - } - logical_block_start += nblocks; - } - } - }); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("stitch_chunk_buffers", &stitch_chunk_buffers); - m.def("split_request_buffer", &split_request_buffer); -} diff --git a/atom/kv_transfer/offload/native_stitch.py b/atom/kv_transfer/offload/native_stitch.py deleted file mode 100644 index 65f40d0450..0000000000 --- a/atom/kv_transfer/offload/native_stitch.py +++ /dev/null @@ -1,56 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -"""Optional native CPU stitch for LMCache chunk buffers. - -This is deliberately a small C++ CPU extension, not a HIP op: current profiles -show H2D at tens of milliseconds, while Python/Torch host repacking takes more -than a second for MiniMax-M2.5 32K. -""" - -from __future__ import annotations - -from pathlib import Path - -from torch.utils.cpp_extension import load - -_EXT = None - - -def _load_ext(): - global _EXT - if _EXT is None: - src = Path(__file__).with_name("native_stitch.cpp") - _EXT = load( - name="atom_lmcache_native_stitch", - sources=[str(src)], - extra_cflags=["-O3"], - verbose=False, - ) - return _EXT - - -def load_extension() -> None: - _load_ext() - - -def stitch_chunk_buffers( - dst, chunk_buffers, chunk_block_counts, seg_block_bytes -) -> None: - _load_ext().stitch_chunk_buffers( - dst, - chunk_buffers, - [int(x) for x in chunk_block_counts], - [int(x) for x in seg_block_bytes], - ) - - -def split_request_buffer( - src, chunk_buffers, chunk_block_counts, seg_block_bytes -) -> None: - _load_ext().split_request_buffer( - src, - chunk_buffers, - [int(x) for x in chunk_block_counts], - [int(x) for x in seg_block_bytes], - ) diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index f2e69713f8..494a0997d4 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -20,7 +20,6 @@ LMCacheOffloadConnector, LMCacheOffloadConnectorScheduler, ) -from atom.kv_transfer.offload import connector as offload_connector_mod from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec from atom.kv_transfer.offload.lmcache_compat import ( ATOMLMCacheGPUConnector, @@ -63,112 +62,6 @@ def _scheduler() -> LMCacheOffloadConnectorScheduler: return sched -@pytest.mark.parametrize("layout", ["segment", "segment_indexed"]) -def test_segment_major_codec_roundtrip_noncontiguous_blocks(monkeypatch, layout): - import torch - - if not hasattr(torch, "arange"): - pytest.skip("real torch is unavailable") - - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", layout) - - original = { - "l0": SimpleNamespace( - k_cache=torch.arange(8 * 2 * 3, dtype=torch.uint8).reshape(8, 2, 3), - v_cache=(torch.arange(8 * 4, dtype=torch.uint8).reshape(8, 4) + 51), - k_scale=torch.arange(8, dtype=torch.uint8).reshape(8, 1) + 101, - v_scale=torch.arange(8, dtype=torch.uint8).reshape(8, 1) + 151, - ), - "l1": SimpleNamespace( - k_cache=(torch.arange(8 * 3, dtype=torch.uint8).reshape(8, 3) + 201), - v_cache=(torch.arange(8 * 2, dtype=torch.uint8).reshape(8, 2) + 31), - k_scale=None, - v_scale=None, - ), - } - kv_caches = { - name: SimpleNamespace( - k_cache=layer.k_cache.clone(), - v_cache=layer.v_cache.clone(), - k_scale=layer.k_scale.clone() if layer.k_scale is not None else None, - v_scale=layer.v_scale.clone() if layer.v_scale is not None else None, - ) - for name, layer in original.items() - } - - codec = ATOMKVByteCodec(kv_caches) - block_ids = [1, 2, 4, 6, 7] - host = torch.empty(len(block_ids) * codec.bytes_per_block, dtype=torch.uint8) - - codec.gpu_to_host(host, block_ids) - expected_calls = codec.segments_per_block * (3 if layout == "segment" else 2) - assert codec.copy_calls_for_block_ids(block_ids) == expected_calls - - for layer in kv_caches.values(): - layer.k_cache.zero_() - layer.v_cache.zero_() - if layer.k_scale is not None: - layer.k_scale.zero_() - if layer.v_scale is not None: - layer.v_scale.zero_() - - codec.host_to_gpu(host, block_ids) - - for name, layer in kv_caches.items(): - src = original[name] - for bid in block_ids: - assert torch.equal(layer.k_cache[bid], src.k_cache[bid]) - assert torch.equal(layer.v_cache[bid], src.v_cache[bid]) - if layer.k_scale is not None: - assert torch.equal(layer.k_scale[bid], src.k_scale[bid]) - if layer.v_scale is not None: - assert torch.equal(layer.v_scale[bid], src.v_scale[bid]) - - -def test_segment_indexed_stitches_chunk_buffers(monkeypatch): - import torch - - if not hasattr(torch, "arange"): - pytest.skip("real torch is unavailable") - - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") - kv_caches = { - "l0": SimpleNamespace( - k_cache=torch.arange(8 * 2 * 3, dtype=torch.uint8).reshape(8, 2, 3), - v_cache=(torch.arange(8 * 4, dtype=torch.uint8).reshape(8, 4) + 51), - k_scale=torch.arange(8, dtype=torch.uint8).reshape(8, 1) + 101, - v_scale=torch.arange(8, dtype=torch.uint8).reshape(8, 1) + 151, - ), - "l1": SimpleNamespace( - k_cache=(torch.arange(8 * 3, dtype=torch.uint8).reshape(8, 3) + 201), - v_cache=(torch.arange(8 * 2, dtype=torch.uint8).reshape(8, 2) + 31), - k_scale=None, - v_scale=None, - ), - } - codec = ATOMKVByteCodec(kv_caches) - chunks = [[1, 2], [4], [6, 7]] - flat_ids = [bid for bids in chunks for bid in bids] - direct = torch.empty(len(flat_ids) * codec.bytes_per_block, dtype=torch.uint8) - codec.gpu_to_host(direct, flat_ids) - - chunk_buffers = [] - for bids in chunks: - host = torch.empty(len(bids) * codec.bytes_per_block, dtype=torch.uint8) - codec.gpu_to_host(host, bids) - chunk_buffers.append(host) - - stitched = torch.empty_like(direct) - codec.stitch_chunk_buffers(stitched, chunk_buffers, [len(b) for b in chunks]) - - assert torch.equal(stitched, direct) - - split_buffers = [torch.empty_like(buf) for buf in chunk_buffers] - codec.split_request_buffer(stitched, split_buffers, [len(b) for b in chunks]) - for actual, expected in zip(split_buffers, chunk_buffers): - assert torch.equal(actual, expected) - - def test_raw_bytes_metadata_shapes_are_block_rounded(): import torch @@ -205,13 +98,12 @@ def test_raw_bytes_metadata_rejects_unaligned_chunk_size(): ) -def test_codec_device_buffer_roundtrip_noncontiguous_blocks(monkeypatch): +def test_codec_device_buffer_roundtrip_noncontiguous_blocks(): import torch if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "block") original = { "l0": SimpleNamespace( k_cache=torch.arange(8 * 2 * 3, dtype=torch.uint8).reshape(8, 2, 3), @@ -257,13 +149,12 @@ def test_codec_device_buffer_roundtrip_noncontiguous_blocks(monkeypatch): assert torch.equal(layer.v_cache[bid], src.v_cache[bid]) -def test_lmcache_connector_maps_token_ranges_to_block_ids(monkeypatch): +def test_lmcache_connector_maps_token_ranges_to_block_ids(): import torch if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "block") original = { "l0": SimpleNamespace( k_cache=torch.arange(6 * 2, dtype=torch.uint8).reshape(6, 2), @@ -317,7 +208,6 @@ def test_lmcache_connector_fused_chunk_fastpath_uses_chunk_major(monkeypatch): if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") monkeypatch.setenv("OFFLOAD_GPU_STAGING_CHUNKS", "2") original = { "l0": SimpleNamespace( @@ -477,7 +367,6 @@ def test_lmcache_connector_fallback_staging_is_chunk_bounded(monkeypatch): if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") monkeypatch.setenv("OFFLOAD_GPU_STAGING_CHUNKS", "1") original = { "l0": SimpleNamespace( @@ -564,7 +453,6 @@ def test_lmcache_connector_release_covers_fallback_chunks(monkeypatch): if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") monkeypatch.setenv("OFFLOAD_GPU_STAGING_CHUNKS", "1") monkeypatch.setenv("OFFLOAD_RELEASE_GPU_STAGING_AFTER_TRANSFER", "1") kv_caches = { @@ -599,13 +487,12 @@ def test_lmcache_connector_release_covers_fallback_chunks(monkeypatch): assert all(slot.tensor is None for slot in state.slots) -def test_lmcache_connector_rejects_oversized_memory_obj(monkeypatch): +def test_lmcache_connector_rejects_oversized_memory_obj(): import torch if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") kv_caches = { "l0": SimpleNamespace( k_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), @@ -635,7 +522,6 @@ def test_lmcache_connector_respects_staging_slot_env(monkeypatch): if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") monkeypatch.setenv("OFFLOAD_GPU_STAGING_SLOTS", "2") monkeypatch.setenv("OFFLOAD_GPU_STAGING_CHUNKS", "3") kv_caches = { @@ -661,7 +547,6 @@ def test_lmcache_connector_default_staging_group_chunks_is_two(monkeypatch): if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") monkeypatch.delenv("OFFLOAD_GPU_STAGING_CHUNKS", raising=False) monkeypatch.delenv("OFFLOAD_GPU_STAGING_MAX_BYTES", raising=False) kv_caches = { @@ -679,13 +564,12 @@ def test_lmcache_connector_default_staging_group_chunks_is_two(monkeypatch): assert connector.gpu_staging_capacity_bytes == 2 * connector.gpu_staging_chunk_bytes -def test_codec_chunk_major_device_buffer_layout(monkeypatch): +def test_codec_chunk_major_device_buffer_layout(): import torch if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") original = { "l0": SimpleNamespace( k_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), @@ -730,13 +614,12 @@ def test_codec_chunk_major_device_buffer_layout(monkeypatch): assert torch.equal(kv_caches["l0"].v_cache, original["l0"].v_cache) -def test_codec_chunk_major_handles_tail_and_sparse_blocks(monkeypatch): +def test_codec_chunk_major_handles_tail_and_sparse_blocks(): import torch if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") original = { "l0": SimpleNamespace( k_cache=torch.arange(6 * 2, dtype=torch.uint8).reshape(6, 2), @@ -785,13 +668,12 @@ def test_codec_chunk_major_handles_tail_and_sparse_blocks(monkeypatch): assert torch.equal(layer.k_scale[bid], src.k_scale[bid]) -def test_codec_chunk_major_rejects_duplicate_block_ids(monkeypatch): +def test_codec_chunk_major_rejects_duplicate_block_ids(): import torch if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") kv_caches = { "l0": SimpleNamespace( k_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), @@ -807,102 +689,6 @@ def test_codec_chunk_major_rejects_duplicate_block_ids(monkeypatch): codec.gpu_to_chunk_major_device_buffer(device_buf, [[0, 1], [1]]) -@pytest.mark.parametrize("layout", ["block", "segment", "segment_indexed"]) -@pytest.mark.parametrize("method_name", ["gpu_to_host", "host_to_gpu"]) -def test_codec_rejects_invalid_block_ids_before_copy(monkeypatch, layout, method_name): - import torch - - if not hasattr(torch, "arange"): - pytest.skip("real torch is unavailable") - - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", layout) - kv_caches = { - "l0": SimpleNamespace( - k_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), - v_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), - k_scale=None, - v_scale=None, - ), - } - codec = ATOMKVByteCodec(kv_caches) - host = torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) - method = getattr(codec, method_name) - - with pytest.raises(ValueError, match="block id out of range"): - method(host, [0, 4]) - - with pytest.raises(ValueError, match="block id out of range"): - method(host, [-1]) - - -def test_codec_rejects_short_host_buffer(monkeypatch): - import torch - - if not hasattr(torch, "arange"): - pytest.skip("real torch is unavailable") - - monkeypatch.setenv("OFFLOAD_CODEC_LAYOUT", "segment_indexed") - kv_caches = { - "l0": SimpleNamespace( - k_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), - v_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), - k_scale=None, - v_scale=None, - ), - } - codec = ATOMKVByteCodec(kv_caches) - host = torch.empty(codec.bytes_per_block - 1, dtype=torch.uint8) - - with pytest.raises(ValueError, match="host_buf is too small"): - codec.gpu_to_host(host, [0]) - - -def test_copy_stream_is_cached_per_codec_device(monkeypatch): - import torch - - if not hasattr(torch, "device") or not hasattr(torch, "cuda"): - pytest.skip("torch cuda API is unavailable") - - conn = LMCacheOffloadConnector.__new__(LMCacheOffloadConnector) - conn._tls = threading.local() - conn._codec = SimpleNamespace(device=torch.device("cuda:1")) - active_devices = [] - created_on = [] - - class _FakeDeviceCtx: - def __init__(self, device) -> None: - self.device = str(device) - - def __enter__(self): - active_devices.append(self.device) - return None - - def __exit__(self, *args): - active_devices.pop() - return False - - class _FakeStream: - def __init__(self) -> None: - created_on.append(active_devices[-1] if active_devices else "default") - - monkeypatch.setattr( - offload_connector_mod.torch.cuda, - "device", - lambda device: _FakeDeviceCtx(device), - ) - monkeypatch.setattr(offload_connector_mod.torch.cuda, "Stream", _FakeStream) - - rank1_stream = conn._stream() - assert conn._stream() is rank1_stream - assert created_on == ["cuda:1"] - - conn._codec = SimpleNamespace(device=torch.device("cuda:0")) - rank0_stream = conn._stream() - - assert rank0_stream is not rank1_stream - assert created_on == ["cuda:1", "cuda:0"] - - def test_full_prompt_hit_is_clamped_before_load_spec(): sched = _scheduler() sched._lookup_client = _LookupClient(hit=8) From 28dc7dfeaf4093802267dae1f65c2420b97c8ed0 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Wed, 3 Jun 2026 22:52:43 -0500 Subject: [PATCH 16/27] Remove unused offload staging sizing hooks --- atom/model_ops/attentions/aiter_attention.py | 49 -------------------- atom/model_ops/attentions/aiter_mla.py | 8 ---- atom/model_ops/attentions/backends.py | 10 ---- 3 files changed, 67 deletions(-) diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index 6784733af9..b4c6fd28ab 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -360,55 +360,6 @@ def compute_block_bytes(self) -> int: ) return block_bytes - def compute_offload_staging_block_bytes(self) -> int: - """Per exposed KVCacheTensor block copied by ATOM offload. - - ``compute_block_bytes()`` is a scheduler-block KV pool estimate. - The offload codec copies slices exposed by ``build_kv_cache_tensor``; - for the AITER MHA layout those slices are physical KV blocks. - """ - from aiter import dtypes - - runner = self.model_runner - config = runner.config - hf_config = config.hf_config - kv_dtype_size = dtypes.d_dtypes[config.kv_cache_dtype].itemsize - physical_block_size = runner.physical_block_size - - def per_layer_bytes(num_kv_heads: int) -> int: - block_bytes = ( - 2 - * num_kv_heads - * hf_config.head_dim - * physical_block_size - * kv_dtype_size - ) - if config.kv_cache_dtype == "fp8": - block_bytes += 2 * num_kv_heads * physical_block_size * 4 - return block_bytes - - if runner.is_mimo_v2(): - pattern = hf_config.hybrid_layer_pattern - num_swa_layers = sum( - 1 for i in range(hf_config.num_hidden_layers) if pattern[i] == 1 - ) - num_full_layers = hf_config.num_hidden_layers - num_swa_layers - num_draft_layers = ( - runner._get_total_num_layers() - hf_config.num_hidden_layers - ) - num_swa_layers += num_draft_layers - _swa_raw = getattr(hf_config, "swa_num_key_value_heads", 0) - swa_kv_heads = ( - _swa_raw // runner.world_size - if _swa_raw >= runner.world_size - else (1 if _swa_raw else 0) - ) - return num_full_layers * per_layer_bytes( - runner._get_num_kv_heads() - ) + num_swa_layers * per_layer_bytes(swa_kv_heads) - - return hf_config.num_hidden_layers * per_layer_bytes(runner._get_num_kv_heads()) - def allocate_kv_cache_tensors( self, num_kv_heads: int, num_draft_layers: int ) -> dict: diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 811af65b51..05735599dd 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -585,14 +585,6 @@ def compute_block_bytes(self) -> int: ) return block_bytes - def compute_offload_staging_block_bytes(self) -> int: - """Per exposed MLA KVCacheTensor block copied by ATOM offload.""" - runner = self.model_runner - config = runner.config - total_num_layers = runner._get_total_num_layers() - kv_dtype_size = dtypes.d_dtypes[config.kv_cache_dtype].itemsize - return total_num_layers * 576 * kv_dtype_size - def allocate_kv_cache_tensors( self, num_kv_heads: int, num_draft_layers: int ) -> dict: diff --git a/atom/model_ops/attentions/backends.py b/atom/model_ops/attentions/backends.py index 7c87abb184..36ac0daf32 100644 --- a/atom/model_ops/attentions/backends.py +++ b/atom/model_ops/attentions/backends.py @@ -167,16 +167,6 @@ def compute_block_bytes(self) -> int: """ return 0 - def compute_offload_staging_block_bytes(self) -> int: - """Bytes copied by ATOM KV offload for one block id. - - This must match ``ATOMKVByteCodec.bytes_per_block`` without requiring - the actual KV tensors to exist yet. The default mirrors the KV pool - block estimate; backends whose exposed ``KVCacheTensor`` block geometry - differs from the scheduler block geometry should override it. - """ - return self.compute_block_bytes() - def allocate_kv_cache_tensors( self, num_kv_heads: int, num_draft_layers: int ) -> dict[str, Any]: From 0901ea92d84e19228ddd7e491f2011feb8158a0f Mon Sep 17 00:00:00 2001 From: yihonglie Date: Wed, 3 Jun 2026 23:03:33 -0500 Subject: [PATCH 17/27] Split LMCache offload metadata and staging helpers --- atom/kv_transfer/offload/connector.py | 6 +- atom/kv_transfer/offload/lmcache_compat.py | 137 +------------------ atom/kv_transfer/offload/lmcache_metadata.py | 67 +++++++++ atom/kv_transfer/offload/lmcache_staging.py | 86 ++++++++++++ tests/test_lmcache_offload_connector.py | 6 +- 5 files changed, 164 insertions(+), 138 deletions(-) create mode 100644 atom/kv_transfer/offload/lmcache_metadata.py create mode 100644 atom/kv_transfer/offload/lmcache_staging.py diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index 9a0387f696..308218be9e 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -39,10 +39,8 @@ from atom.kv_transfer.disaggregation.types import KVConnectorOutput, ReqId from atom.kv_transfer.offload import config as offcfg from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec -from atom.kv_transfer.offload.lmcache_compat import ( - ATOMLMCacheGPUConnector, - ATOMRawBytesLMCacheMetadata, -) +from atom.kv_transfer.offload.lmcache_compat import ATOMLMCacheGPUConnector +from atom.kv_transfer.offload.lmcache_metadata import ATOMRawBytesLMCacheMetadata from atom.kv_transfer.offload.metadata import ( LMCacheOffloadMetadata, LMCacheReqMeta, diff --git a/atom/kv_transfer/offload/lmcache_compat.py b/atom/kv_transfer/offload/lmcache_compat.py index fad16df602..311d920d04 100644 --- a/atom/kv_transfer/offload/lmcache_compat.py +++ b/atom/kv_transfer/offload/lmcache_compat.py @@ -13,7 +13,6 @@ from __future__ import annotations from dataclasses import dataclass -import os import threading import time from typing import Any @@ -21,141 +20,19 @@ import torch from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec +from atom.kv_transfer.offload.lmcache_staging import ( + _StagingSlot, + _ThreadTransferState, + _env_flag, + _env_int, + _env_optional_int, +) def _cdiv(a: int, b: int) -> int: return -(-int(a) // int(b)) -class ATOMRawBytesLMCacheMetadata: - """Proxy around ``LMCacheMetadata`` with ATOM raw-byte allocation shapes.""" - - def __init__( - self, - base_metadata: Any, - *, - atom_block_size: int, - bytes_per_block: int, - ) -> None: - self._atom_base_metadata = base_metadata - self.__dict__.update(vars(base_metadata)) - self.atom_block_size = int(atom_block_size) - self.atom_bytes_per_block = int(bytes_per_block) - chunk_size = int(getattr(base_metadata, "chunk_size")) - if self.atom_block_size <= 0: - raise ValueError("ATOM raw-byte metadata: atom_block_size must be > 0") - if self.atom_bytes_per_block <= 0: - raise ValueError("ATOM raw-byte metadata: bytes_per_block must be > 0") - if chunk_size % self.atom_block_size != 0: - raise ValueError( - "LMCache chunk size must be divisible by ATOM KV block size: " - f"chunk_size={chunk_size}, block_size={self.atom_block_size}" - ) - - def __getattr__(self, name: str) -> Any: - return getattr(self._atom_base_metadata, name) - - def __eq__(self, other: object) -> bool: - if isinstance(other, ATOMRawBytesLMCacheMetadata): - return ( - self._atom_base_metadata == other._atom_base_metadata - and self.atom_block_size == other.atom_block_size - and self.atom_bytes_per_block == other.atom_bytes_per_block - ) - return False - - def is_first_rank(self) -> bool: - return self._atom_base_metadata.is_first_rank() - - def get_dtypes(self) -> list[torch.dtype]: - return [torch.uint8] - - def get_shapes(self, num_tokens: int | None = None) -> list[torch.Size]: - if num_tokens is None: - num_tokens = int(self.chunk_size) - nblocks = _cdiv(int(num_tokens), self.atom_block_size) - return [torch.Size((nblocks * self.atom_bytes_per_block,))] - - def get_num_groups(self) -> int: - return 1 - - -class _NullCtx: - def __enter__(self): - return None - - def __exit__(self, *args): - return False - - -class _StagingSlot: - def __init__(self, use_cuda: bool) -> None: - self.tensor: torch.Tensor | None = None - self.ready_event = None - self.free_event = None - self.free_event_valid = False - if use_cuda: - self.ready_event = torch.cuda.Event(blocking=False) - self.free_event = torch.cuda.Event(blocking=False) - - -def _env_flag(name: str, default: str = "0") -> bool: - return os.environ.get(name, default).lower() not in ("0", "false", "no", "off") - - -def _env_int(name: str, default: int, *, min_value: int = 1) -> int: - raw = os.environ.get(name) - if raw is None: - return default - try: - value = int(raw) - except ValueError as exc: - raise ValueError(f"{name} must be an integer, got {raw!r}") from exc - if value < min_value: - raise ValueError(f"{name} must be >= {min_value}, got {value}") - return value - - -def _env_optional_int(name: str, *, min_value: int = 1) -> int | None: - raw = os.environ.get(name) - if raw is None or raw == "": - return None - try: - value = int(raw) - except ValueError as exc: - raise ValueError(f"{name} must be an integer, got {raw!r}") from exc - if value < min_value: - raise ValueError(f"{name} must be >= {min_value}, got {value}") - return value - - -class _ThreadTransferState: - def __init__( - self, - device: torch.device, - use_cuda: bool, - staging_slots: int, - ) -> None: - self.device = device - self.use_cuda = use_cuda - self.pack_stream = None - self.copy_stream = None - self.next_slot = 0 - self.host_tmp: torch.Tensor | None = None - if use_cuda: - with torch.cuda.device(device): - self.pack_stream = torch.cuda.Stream() - self.copy_stream = torch.cuda.Stream() - self.slots = [_StagingSlot(use_cuda) for _ in range(staging_slots)] - else: - self.slots = [_StagingSlot(use_cuda) for _ in range(staging_slots)] - - def stream_ctx(self, stream): - if stream is None: - return _NullCtx() - return torch.cuda.stream(stream) - - @dataclass(frozen=True) class _TransferChunk: memory_obj: Any diff --git a/atom/kv_transfer/offload/lmcache_metadata.py b/atom/kv_transfer/offload/lmcache_metadata.py new file mode 100644 index 0000000000..4788880e12 --- /dev/null +++ b/atom/kv_transfer/offload/lmcache_metadata.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""LMCache metadata wrapper for ATOM raw-byte KV offload.""" + +from __future__ import annotations + +from typing import Any + +import torch + + +def _cdiv(a: int, b: int) -> int: + return -(-int(a) // int(b)) + + +class ATOMRawBytesLMCacheMetadata: + """Proxy around ``LMCacheMetadata`` with ATOM raw-byte allocation shapes.""" + + def __init__( + self, + base_metadata: Any, + *, + atom_block_size: int, + bytes_per_block: int, + ) -> None: + self._atom_base_metadata = base_metadata + self.__dict__.update(vars(base_metadata)) + self.atom_block_size = int(atom_block_size) + self.atom_bytes_per_block = int(bytes_per_block) + chunk_size = int(getattr(base_metadata, "chunk_size")) + if self.atom_block_size <= 0: + raise ValueError("ATOM raw-byte metadata: atom_block_size must be > 0") + if self.atom_bytes_per_block <= 0: + raise ValueError("ATOM raw-byte metadata: bytes_per_block must be > 0") + if chunk_size % self.atom_block_size != 0: + raise ValueError( + "LMCache chunk size must be divisible by ATOM KV block size: " + f"chunk_size={chunk_size}, block_size={self.atom_block_size}" + ) + + def __getattr__(self, name: str) -> Any: + return getattr(self._atom_base_metadata, name) + + def __eq__(self, other: object) -> bool: + if isinstance(other, ATOMRawBytesLMCacheMetadata): + return ( + self._atom_base_metadata == other._atom_base_metadata + and self.atom_block_size == other.atom_block_size + and self.atom_bytes_per_block == other.atom_bytes_per_block + ) + return False + + def is_first_rank(self) -> bool: + return self._atom_base_metadata.is_first_rank() + + def get_dtypes(self) -> list[torch.dtype]: + return [torch.uint8] + + def get_shapes(self, num_tokens: int | None = None) -> list[torch.Size]: + if num_tokens is None: + num_tokens = int(self.chunk_size) + nblocks = _cdiv(int(num_tokens), self.atom_block_size) + return [torch.Size((nblocks * self.atom_bytes_per_block,))] + + def get_num_groups(self) -> int: + return 1 diff --git a/atom/kv_transfer/offload/lmcache_staging.py b/atom/kv_transfer/offload/lmcache_staging.py new file mode 100644 index 0000000000..4404503cda --- /dev/null +++ b/atom/kv_transfer/offload/lmcache_staging.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Staging-buffer helpers for the ATOM LMCache GPU connector.""" + +from __future__ import annotations + +import os + +import torch + + +class _NullCtx: + def __enter__(self): + return None + + def __exit__(self, *args): + return False + + +class _StagingSlot: + def __init__(self, use_cuda: bool) -> None: + self.tensor: torch.Tensor | None = None + self.ready_event = None + self.free_event = None + self.free_event_valid = False + if use_cuda: + self.ready_event = torch.cuda.Event(blocking=False) + self.free_event = torch.cuda.Event(blocking=False) + + +def _env_flag(name: str, default: str = "0") -> bool: + return os.environ.get(name, default).lower() not in ("0", "false", "no", "off") + + +def _env_int(name: str, default: int, *, min_value: int = 1) -> int: + raw = os.environ.get(name) + if raw is None: + return default + try: + value = int(raw) + except ValueError as exc: + raise ValueError(f"{name} must be an integer, got {raw!r}") from exc + if value < min_value: + raise ValueError(f"{name} must be >= {min_value}, got {value}") + return value + + +def _env_optional_int(name: str, *, min_value: int = 1) -> int | None: + raw = os.environ.get(name) + if raw is None or raw == "": + return None + try: + value = int(raw) + except ValueError as exc: + raise ValueError(f"{name} must be an integer, got {raw!r}") from exc + if value < min_value: + raise ValueError(f"{name} must be >= {min_value}, got {value}") + return value + + +class _ThreadTransferState: + def __init__( + self, + device: torch.device, + use_cuda: bool, + staging_slots: int, + ) -> None: + self.device = device + self.use_cuda = use_cuda + self.pack_stream = None + self.copy_stream = None + self.next_slot = 0 + self.host_tmp: torch.Tensor | None = None + if use_cuda: + with torch.cuda.device(device): + self.pack_stream = torch.cuda.Stream() + self.copy_stream = torch.cuda.Stream() + self.slots = [_StagingSlot(use_cuda) for _ in range(staging_slots)] + else: + self.slots = [_StagingSlot(use_cuda) for _ in range(staging_slots)] + + def stream_ctx(self, stream): + if stream is None: + return _NullCtx() + return torch.cuda.stream(stream) diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index 494a0997d4..aa65e0c92e 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -21,10 +21,8 @@ LMCacheOffloadConnectorScheduler, ) from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec -from atom.kv_transfer.offload.lmcache_compat import ( - ATOMLMCacheGPUConnector, - ATOMRawBytesLMCacheMetadata, -) +from atom.kv_transfer.offload.lmcache_compat import ATOMLMCacheGPUConnector +from atom.kv_transfer.offload.lmcache_metadata import ATOMRawBytesLMCacheMetadata from atom.model_engine.scheduler import Scheduler From 0cabc8812ab3dd8488a12001cc4888131beeb8dc Mon Sep 17 00:00:00 2001 From: yihonglie Date: Thu, 4 Jun 2026 00:04:59 -0500 Subject: [PATCH 18/27] Merge LMCache metadata wrapper into offload metadata --- atom/kv_transfer/offload/connector.py | 2 +- atom/kv_transfer/offload/lmcache_metadata.py | 67 ------------------ atom/kv_transfer/offload/metadata.py | 74 +++++++++++++++++--- tests/test_lmcache_offload_connector.py | 2 +- 4 files changed, 68 insertions(+), 77 deletions(-) delete mode 100644 atom/kv_transfer/offload/lmcache_metadata.py diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index 308218be9e..88869d92cd 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -40,8 +40,8 @@ from atom.kv_transfer.offload import config as offcfg from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec from atom.kv_transfer.offload.lmcache_compat import ATOMLMCacheGPUConnector -from atom.kv_transfer.offload.lmcache_metadata import ATOMRawBytesLMCacheMetadata from atom.kv_transfer.offload.metadata import ( + ATOMRawBytesLMCacheMetadata, LMCacheOffloadMetadata, LMCacheReqMeta, LoadSpec, diff --git a/atom/kv_transfer/offload/lmcache_metadata.py b/atom/kv_transfer/offload/lmcache_metadata.py deleted file mode 100644 index 4788880e12..0000000000 --- a/atom/kv_transfer/offload/lmcache_metadata.py +++ /dev/null @@ -1,67 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -"""LMCache metadata wrapper for ATOM raw-byte KV offload.""" - -from __future__ import annotations - -from typing import Any - -import torch - - -def _cdiv(a: int, b: int) -> int: - return -(-int(a) // int(b)) - - -class ATOMRawBytesLMCacheMetadata: - """Proxy around ``LMCacheMetadata`` with ATOM raw-byte allocation shapes.""" - - def __init__( - self, - base_metadata: Any, - *, - atom_block_size: int, - bytes_per_block: int, - ) -> None: - self._atom_base_metadata = base_metadata - self.__dict__.update(vars(base_metadata)) - self.atom_block_size = int(atom_block_size) - self.atom_bytes_per_block = int(bytes_per_block) - chunk_size = int(getattr(base_metadata, "chunk_size")) - if self.atom_block_size <= 0: - raise ValueError("ATOM raw-byte metadata: atom_block_size must be > 0") - if self.atom_bytes_per_block <= 0: - raise ValueError("ATOM raw-byte metadata: bytes_per_block must be > 0") - if chunk_size % self.atom_block_size != 0: - raise ValueError( - "LMCache chunk size must be divisible by ATOM KV block size: " - f"chunk_size={chunk_size}, block_size={self.atom_block_size}" - ) - - def __getattr__(self, name: str) -> Any: - return getattr(self._atom_base_metadata, name) - - def __eq__(self, other: object) -> bool: - if isinstance(other, ATOMRawBytesLMCacheMetadata): - return ( - self._atom_base_metadata == other._atom_base_metadata - and self.atom_block_size == other.atom_block_size - and self.atom_bytes_per_block == other.atom_bytes_per_block - ) - return False - - def is_first_rank(self) -> bool: - return self._atom_base_metadata.is_first_rank() - - def get_dtypes(self) -> list[torch.dtype]: - return [torch.uint8] - - def get_shapes(self, num_tokens: int | None = None) -> list[torch.Size]: - if num_tokens is None: - num_tokens = int(self.chunk_size) - nblocks = _cdiv(int(num_tokens), self.atom_block_size) - return [torch.Size((nblocks * self.atom_bytes_per_block,))] - - def get_num_groups(self) -> int: - return 1 diff --git a/atom/kv_transfer/offload/metadata.py b/atom/kv_transfer/offload/metadata.py index 831dcc9658..7d5452b403 100644 --- a/atom/kv_transfer/offload/metadata.py +++ b/atom/kv_transfer/offload/metadata.py @@ -1,23 +1,81 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -"""Per-request transfer descriptors for the LMCache CPU/NVMe offload connector. - -Ported (type-substituted) from vLLM's ``lmcache_integration/vllm_v1_adapter.py`` -(``LoadSpec`` / ``SaveSpec`` / ``RequestTracker`` / ``ReqMeta``) onto ATOM's -``Sequence`` model. These travel from the scheduler-side connector to the -worker-side connector inside :class:`LMCacheOffloadMetadata`, which subclasses -ATOM's :class:`ConnectorMetadata` so the engine forwards it opaquely through -``process_kvconnector_output`` → ``start_load_kv``. +"""Metadata helpers for the LMCache CPU/NVMe offload connector. + +``ATOMRawBytesLMCacheMetadata`` adapts LMCache's engine metadata so MemoryObjs +are allocated as opaque uint8 buffers. The remaining classes are per-request +transfer descriptors that travel from the scheduler-side connector to the +worker-side connector inside :class:`LMCacheOffloadMetadata`. """ from __future__ import annotations from dataclasses import dataclass +from typing import Any + +import torch from atom.kv_transfer.disaggregation.types import ConnectorMetadata, ReqId +def _cdiv(a: int, b: int) -> int: + return -(-int(a) // int(b)) + + +class ATOMRawBytesLMCacheMetadata: + """Proxy around ``LMCacheMetadata`` with ATOM raw-byte allocation shapes.""" + + def __init__( + self, + base_metadata: Any, + *, + atom_block_size: int, + bytes_per_block: int, + ) -> None: + self._atom_base_metadata = base_metadata + self.__dict__.update(vars(base_metadata)) + self.atom_block_size = int(atom_block_size) + self.atom_bytes_per_block = int(bytes_per_block) + chunk_size = int(getattr(base_metadata, "chunk_size")) + if self.atom_block_size <= 0: + raise ValueError("ATOM raw-byte metadata: atom_block_size must be > 0") + if self.atom_bytes_per_block <= 0: + raise ValueError("ATOM raw-byte metadata: bytes_per_block must be > 0") + if chunk_size % self.atom_block_size != 0: + raise ValueError( + "LMCache chunk size must be divisible by ATOM KV block size: " + f"chunk_size={chunk_size}, block_size={self.atom_block_size}" + ) + + def __getattr__(self, name: str) -> Any: + return getattr(self._atom_base_metadata, name) + + def __eq__(self, other: object) -> bool: + if isinstance(other, ATOMRawBytesLMCacheMetadata): + return ( + self._atom_base_metadata == other._atom_base_metadata + and self.atom_block_size == other.atom_block_size + and self.atom_bytes_per_block == other.atom_bytes_per_block + ) + return False + + def is_first_rank(self) -> bool: + return self._atom_base_metadata.is_first_rank() + + def get_dtypes(self) -> list[torch.dtype]: + return [torch.uint8] + + def get_shapes(self, num_tokens: int | None = None) -> list[torch.Size]: + if num_tokens is None: + num_tokens = int(self.chunk_size) + nblocks = _cdiv(int(num_tokens), self.atom_block_size) + return [torch.Size((nblocks * self.atom_bytes_per_block,))] + + def get_num_groups(self) -> int: + return 1 + + @dataclass class LoadSpec: """How many tokens to load for a request, split HBM-cached vs LMCache-cached.""" diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index aa65e0c92e..268218efd8 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -22,7 +22,7 @@ ) from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec from atom.kv_transfer.offload.lmcache_compat import ATOMLMCacheGPUConnector -from atom.kv_transfer.offload.lmcache_metadata import ATOMRawBytesLMCacheMetadata +from atom.kv_transfer.offload.metadata import ATOMRawBytesLMCacheMetadata from atom.model_engine.scheduler import Scheduler From 86073f421fc091e363ea1dc0af4455bcc1eb4c5b Mon Sep 17 00:00:00 2001 From: yihonglie Date: Thu, 4 Jun 2026 00:17:09 -0500 Subject: [PATCH 19/27] Replace native KV staging with Triton kernel --- atom/kv_transfer/offload/gpu_connector.py | 37 ++- atom/kv_transfer/offload/lmcache_compat.py | 2 +- .../kv_transfer/offload/native_kv_staging.cpp | 222 -------------- atom/kv_transfer/offload/native_kv_staging.py | 65 ----- .../offload/native_kv_staging_kernel.hip | 145 ---------- atom/kv_transfer/offload/triton_kv_staging.py | 273 ++++++++++++++++++ 6 files changed, 296 insertions(+), 448 deletions(-) delete mode 100644 atom/kv_transfer/offload/native_kv_staging.cpp delete mode 100644 atom/kv_transfer/offload/native_kv_staging.py delete mode 100644 atom/kv_transfer/offload/native_kv_staging_kernel.hip create mode 100644 atom/kv_transfer/offload/triton_kv_staging.py diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py index 8cb50400e2..78bbf69327 100644 --- a/atom/kv_transfer/offload/gpu_connector.py +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -78,18 +78,25 @@ def __init__(self, kv_caches: dict) -> None: int(t[0].numel()) * t.element_size() for t in self._segments ] self.bytes_per_block: int = sum(self._seg_block_bytes) - self._native_kv_staging = None - if self._device.type == "cuda" and os.environ.get( - "OFFLOAD_NATIVE_KV_STAGING", "0" - ).lower() not in ("0", "false", "no", "off"): + self._fused_kv_staging = None + fused_env = os.environ.get( + "OFFLOAD_FUSED_KV_STAGING", + os.environ.get("OFFLOAD_NATIVE_KV_STAGING", "0"), + ) + if self._device.type == "cuda" and fused_env.lower() not in ( + "0", + "false", + "no", + "off", + ): try: - from atom.kv_transfer.offload import native_kv_staging + from atom.kv_transfer.offload import triton_kv_staging - native_kv_staging.load_extension() - self._native_kv_staging = native_kv_staging + triton_kv_staging.load_extension() + self._fused_kv_staging = triton_kv_staging except Exception: logger.warning( - "ATOMKVByteCodec: native KV staging unavailable; " + "ATOMKVByteCodec: Triton KV staging unavailable; " "using chunk fallback", exc_info=True, ) @@ -99,8 +106,8 @@ def device(self) -> torch.device: return self._device @property - def has_native_chunk_major_staging(self) -> bool: - return self._native_kv_staging is not None + def has_fused_chunk_major_staging(self) -> bool: + return self._fused_kv_staging is not None # -- helpers ---------------------------------------------------------- def _segment_bases(self, nblocks: int) -> list[int]: @@ -231,7 +238,7 @@ def gpu_to_chunk_major_device_buffer( Layout is MemoryObj-compatible: ``[chunk0: seg0 blocks | seg1 blocks | ...][chunk1: ...]``. - Native fused staging is used when available; otherwise this method + Fused Triton staging is used when available; otherwise this method provides a reference implementation for tests and CPU fallback. """ groups, flat_block_ids, chunk_block_counts = self._normalize_block_id_groups( @@ -244,8 +251,8 @@ def gpu_to_chunk_major_device_buffer( with self._device_ctx(): stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() with stream_ctx: - if self._native_kv_staging is not None: - self._native_kv_staging.fused_pack_chunk_major( + if self._fused_kv_staging is not None: + self._fused_kv_staging.fused_pack_chunk_major( self._segments, self._seg_block_bytes, chunk_block_counts, @@ -282,8 +289,8 @@ def chunk_major_device_buffer_to_gpu( with self._device_ctx(): stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() with stream_ctx: - if self._native_kv_staging is not None: - self._native_kv_staging.fused_unpack_chunk_major( + if self._fused_kv_staging is not None: + self._fused_kv_staging.fused_unpack_chunk_major( device_buf, self._segments, self._seg_block_bytes, diff --git a/atom/kv_transfer/offload/lmcache_compat.py b/atom/kv_transfer/offload/lmcache_compat.py index 311d920d04..311f181e62 100644 --- a/atom/kv_transfer/offload/lmcache_compat.py +++ b/atom/kv_transfer/offload/lmcache_compat.py @@ -254,7 +254,7 @@ def _ensure_host_tmp( return state.host_tmp[:nbytes] def _can_use_fused_chunk_major(self) -> bool: - return self._use_cuda() and self.codec.has_native_chunk_major_staging + return self._use_cuda() and self.codec.has_fused_chunk_major_staging def _memory_tensor(self, memory_obj: Any, nbytes: int) -> torch.Tensor: tensor = getattr(memory_obj, "tensor", None) diff --git a/atom/kv_transfer/offload/native_kv_staging.cpp b/atom/kv_transfer/offload/native_kv_staging.cpp deleted file mode 100644 index 9dd558354d..0000000000 --- a/atom/kv_transfer/offload/native_kv_staging.cpp +++ /dev/null @@ -1,222 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include - -#include -#include - -#define ATOM_HIP_CHECK(cmd) \ - do { \ - hipError_t err = (cmd); \ - TORCH_CHECK( \ - err == hipSuccess, "HIP error: ", hipGetErrorString(err)); \ - } while (0) - -hipError_t launch_fused_pack_chunk_major( - uint8_t* device_buf, - const int64_t* segment_ptrs, - const int64_t* segment_block_bytes, - const int64_t* segment_prefix_bytes, - const int64_t* chunk_block_counts, - const int64_t* chunk_block_offsets, - const int64_t* chunk_output_bases, - const int64_t* block_ids, - int64_t num_chunks, - int64_t num_segments, - hipStream_t stream); - -hipError_t launch_fused_unpack_chunk_major( - const uint8_t* device_buf, - const int64_t* segment_ptrs, - const int64_t* segment_block_bytes, - const int64_t* segment_prefix_bytes, - const int64_t* chunk_block_counts, - const int64_t* chunk_block_offsets, - const int64_t* chunk_output_bases, - const int64_t* block_ids, - int64_t num_chunks, - int64_t num_segments, - hipStream_t stream); - -namespace { - -torch::Tensor make_device_i64( - const std::vector& values, - const torch::Device& device, - hipStream_t stream) { - auto tensor = torch::empty( - {static_cast(values.size())}, - torch::TensorOptions().dtype(torch::kInt64).device(device)); - if (!values.empty()) { - ATOM_HIP_CHECK(hipMemcpyAsync( - tensor.data_ptr(), - values.data(), - values.size() * sizeof(int64_t), - hipMemcpyHostToDevice, - stream)); - } - return tensor; -} - -struct StagingMeta { - torch::Tensor segment_ptrs; - torch::Tensor segment_block_bytes; - torch::Tensor segment_prefix_bytes; - torch::Tensor chunk_block_counts; - torch::Tensor chunk_block_offsets; - torch::Tensor chunk_output_bases; - torch::Tensor block_ids; - int64_t num_chunks = 0; - int64_t num_segments = 0; - int64_t total_bytes = 0; -}; - -StagingMeta build_meta( - const std::vector& segment_tensors, - const std::vector& segment_block_bytes, - const std::vector& chunk_block_counts, - const std::vector& block_ids, - torch::Tensor device_buf, - hipStream_t stream) { - TORCH_CHECK(device_buf.is_cuda(), "device_buf must be a CUDA/HIP tensor"); - TORCH_CHECK(device_buf.dtype() == torch::kUInt8, "device_buf must be uint8"); - TORCH_CHECK(device_buf.is_contiguous(), "device_buf must be contiguous"); - TORCH_CHECK( - segment_tensors.size() == segment_block_bytes.size(), - "segment_tensors and segment_block_bytes size mismatch"); - - const int64_t num_segments = static_cast(segment_tensors.size()); - const int64_t num_chunks = static_cast(chunk_block_counts.size()); - TORCH_CHECK(num_segments > 0, "at least one segment is required"); - - std::vector segment_ptr_values(num_segments); - std::vector segment_prefix_values(num_segments); - int64_t bytes_per_block = 0; - for (int64_t i = 0; i < num_segments; ++i) { - const auto& seg = segment_tensors[i]; - TORCH_CHECK(seg.is_cuda(), "segment tensor must be CUDA/HIP"); - TORCH_CHECK(seg.device() == device_buf.device(), "segment/device mismatch"); - TORCH_CHECK(seg.is_contiguous(), "segment tensor must be contiguous"); - TORCH_CHECK(segment_block_bytes[i] > 0, "segment block bytes must be > 0"); - segment_ptr_values[i] = - reinterpret_cast(static_cast(seg.data_ptr())); - segment_prefix_values[i] = bytes_per_block; - bytes_per_block += segment_block_bytes[i]; - } - - std::vector chunk_block_offsets(num_chunks); - std::vector chunk_output_bases(num_chunks); - int64_t block_offset = 0; - int64_t byte_offset = 0; - for (int64_t c = 0; c < num_chunks; ++c) { - const int64_t nblocks = chunk_block_counts[c]; - TORCH_CHECK(nblocks >= 0, "chunk block count must be non-negative"); - chunk_block_offsets[c] = block_offset; - chunk_output_bases[c] = byte_offset; - block_offset += nblocks; - byte_offset += nblocks * bytes_per_block; - } - TORCH_CHECK( - static_cast(block_ids.size()) == block_offset, - "block_ids length does not match chunk block counts"); - TORCH_CHECK( - device_buf.numel() >= byte_offset, - "device_buf is smaller than chunk-major staging output"); - - StagingMeta meta; - meta.segment_ptrs = make_device_i64(segment_ptr_values, device_buf.device(), stream); - meta.segment_block_bytes = - make_device_i64(segment_block_bytes, device_buf.device(), stream); - meta.segment_prefix_bytes = - make_device_i64(segment_prefix_values, device_buf.device(), stream); - meta.chunk_block_counts = - make_device_i64(chunk_block_counts, device_buf.device(), stream); - meta.chunk_block_offsets = - make_device_i64(chunk_block_offsets, device_buf.device(), stream); - meta.chunk_output_bases = - make_device_i64(chunk_output_bases, device_buf.device(), stream); - meta.block_ids = make_device_i64(block_ids, device_buf.device(), stream); - meta.num_chunks = num_chunks; - meta.num_segments = num_segments; - meta.total_bytes = byte_offset; - return meta; -} - -hipStream_t current_hip_stream(torch::Device device) { - c10::hip::HIPGuardMasqueradingAsCUDA guard(device); - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(device.index()); - return stream.stream(); -} - -} // namespace - -void fused_pack_chunk_major( - std::vector segment_tensors, - std::vector segment_block_bytes, - std::vector chunk_block_counts, - std::vector block_ids, - torch::Tensor device_buf) { - auto stream = current_hip_stream(device_buf.device()); - auto meta = build_meta( - segment_tensors, - segment_block_bytes, - chunk_block_counts, - block_ids, - device_buf, - stream); - if (meta.total_bytes == 0) { - return; - } - ATOM_HIP_CHECK(launch_fused_pack_chunk_major( - device_buf.data_ptr(), - meta.segment_ptrs.data_ptr(), - meta.segment_block_bytes.data_ptr(), - meta.segment_prefix_bytes.data_ptr(), - meta.chunk_block_counts.data_ptr(), - meta.chunk_block_offsets.data_ptr(), - meta.chunk_output_bases.data_ptr(), - meta.block_ids.data_ptr(), - meta.num_chunks, - meta.num_segments, - stream)); -} - -void fused_unpack_chunk_major( - torch::Tensor device_buf, - std::vector segment_tensors, - std::vector segment_block_bytes, - std::vector chunk_block_counts, - std::vector block_ids) { - auto stream = current_hip_stream(device_buf.device()); - auto meta = build_meta( - segment_tensors, - segment_block_bytes, - chunk_block_counts, - block_ids, - device_buf, - stream); - if (meta.total_bytes == 0) { - return; - } - ATOM_HIP_CHECK(launch_fused_unpack_chunk_major( - device_buf.data_ptr(), - meta.segment_ptrs.data_ptr(), - meta.segment_block_bytes.data_ptr(), - meta.segment_prefix_bytes.data_ptr(), - meta.chunk_block_counts.data_ptr(), - meta.chunk_block_offsets.data_ptr(), - meta.chunk_output_bases.data_ptr(), - meta.block_ids.data_ptr(), - meta.num_chunks, - meta.num_segments, - stream)); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("fused_pack_chunk_major", &fused_pack_chunk_major); - m.def("fused_unpack_chunk_major", &fused_unpack_chunk_major); -} diff --git a/atom/kv_transfer/offload/native_kv_staging.py b/atom/kv_transfer/offload/native_kv_staging.py deleted file mode 100644 index 73b33fdf5b..0000000000 --- a/atom/kv_transfer/offload/native_kv_staging.py +++ /dev/null @@ -1,65 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -"""Optional HIP fused chunk-major staging for ATOM LMCache offload.""" - -from __future__ import annotations - -from pathlib import Path - -from torch.utils.cpp_extension import load - -_EXT = None - - -def _load_ext(): - global _EXT - if _EXT is None: - base = Path(__file__).parent - _EXT = load( - name="atom_lmcache_native_kv_staging", - sources=[ - str(base / "native_kv_staging.cpp"), - str(base / "native_kv_staging_kernel.hip"), - ], - extra_cflags=["-O3"], - extra_cuda_cflags=["-O3"], - verbose=False, - ) - return _EXT - - -def load_extension() -> None: - _load_ext() - - -def fused_pack_chunk_major( - segment_tensors, - segment_block_bytes, - chunk_block_counts, - block_ids, - device_buf, -) -> None: - _load_ext().fused_pack_chunk_major( - segment_tensors, - [int(x) for x in segment_block_bytes], - [int(x) for x in chunk_block_counts], - [int(x) for x in block_ids], - device_buf, - ) - - -def fused_unpack_chunk_major( - device_buf, - segment_tensors, - segment_block_bytes, - chunk_block_counts, - block_ids, -) -> None: - _load_ext().fused_unpack_chunk_major( - device_buf, - segment_tensors, - [int(x) for x in segment_block_bytes], - [int(x) for x in chunk_block_counts], - [int(x) for x in block_ids], - ) diff --git a/atom/kv_transfer/offload/native_kv_staging_kernel.hip b/atom/kv_transfer/offload/native_kv_staging_kernel.hip deleted file mode 100644 index 652276be3e..0000000000 --- a/atom/kv_transfer/offload/native_kv_staging_kernel.hip +++ /dev/null @@ -1,145 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include - -namespace { - -constexpr int kThreads = 256; - -__global__ void pack_chunk_major_kernel( - uint8_t* __restrict__ device_buf, - const int64_t* __restrict__ segment_ptrs, - const int64_t* __restrict__ segment_block_bytes, - const int64_t* __restrict__ segment_prefix_bytes, - const int64_t* __restrict__ chunk_block_counts, - const int64_t* __restrict__ chunk_block_offsets, - const int64_t* __restrict__ chunk_output_bases, - const int64_t* __restrict__ block_ids, - int64_t num_segments) { - const int64_t job = static_cast(blockIdx.x); - const int64_t chunk_id = job / num_segments; - const int64_t seg_id = job - chunk_id * num_segments; - const int64_t nblocks = chunk_block_counts[chunk_id]; - const int64_t seg_bytes = segment_block_bytes[seg_id]; - const int64_t nbytes = nblocks * seg_bytes; - if (nbytes <= 0) { - return; - } - - const auto* src_base = - reinterpret_cast(segment_ptrs[seg_id]); - uint8_t* dst_base = device_buf + chunk_output_bases[chunk_id] + - segment_prefix_bytes[seg_id] * nblocks; - const int64_t block_offset = chunk_block_offsets[chunk_id]; - - for (int64_t i = static_cast(threadIdx.x); i < nbytes; - i += static_cast(blockDim.x)) { - const int64_t local_block = i / seg_bytes; - const int64_t byte_in_block = i - local_block * seg_bytes; - const int64_t physical_block = block_ids[block_offset + local_block]; - dst_base[i] = src_base[physical_block * seg_bytes + byte_in_block]; - } -} - -__global__ void unpack_chunk_major_kernel( - const uint8_t* __restrict__ device_buf, - const int64_t* __restrict__ segment_ptrs, - const int64_t* __restrict__ segment_block_bytes, - const int64_t* __restrict__ segment_prefix_bytes, - const int64_t* __restrict__ chunk_block_counts, - const int64_t* __restrict__ chunk_block_offsets, - const int64_t* __restrict__ chunk_output_bases, - const int64_t* __restrict__ block_ids, - int64_t num_segments) { - const int64_t job = static_cast(blockIdx.x); - const int64_t chunk_id = job / num_segments; - const int64_t seg_id = job - chunk_id * num_segments; - const int64_t nblocks = chunk_block_counts[chunk_id]; - const int64_t seg_bytes = segment_block_bytes[seg_id]; - const int64_t nbytes = nblocks * seg_bytes; - if (nbytes <= 0) { - return; - } - - const auto* src_base = device_buf + chunk_output_bases[chunk_id] + - segment_prefix_bytes[seg_id] * nblocks; - auto* dst_base = reinterpret_cast(segment_ptrs[seg_id]); - const int64_t block_offset = chunk_block_offsets[chunk_id]; - - for (int64_t i = static_cast(threadIdx.x); i < nbytes; - i += static_cast(blockDim.x)) { - const int64_t local_block = i / seg_bytes; - const int64_t byte_in_block = i - local_block * seg_bytes; - const int64_t physical_block = block_ids[block_offset + local_block]; - dst_base[physical_block * seg_bytes + byte_in_block] = src_base[i]; - } -} - -} // namespace - -hipError_t launch_fused_pack_chunk_major( - uint8_t* device_buf, - const int64_t* segment_ptrs, - const int64_t* segment_block_bytes, - const int64_t* segment_prefix_bytes, - const int64_t* chunk_block_counts, - const int64_t* chunk_block_offsets, - const int64_t* chunk_output_bases, - const int64_t* block_ids, - int64_t num_chunks, - int64_t num_segments, - hipStream_t stream) { - const dim3 grid(static_cast(num_chunks * num_segments)); - const dim3 block(kThreads); - hipLaunchKernelGGL( - pack_chunk_major_kernel, - grid, - block, - 0, - stream, - device_buf, - segment_ptrs, - segment_block_bytes, - segment_prefix_bytes, - chunk_block_counts, - chunk_block_offsets, - chunk_output_bases, - block_ids, - num_segments); - return hipGetLastError(); -} - -hipError_t launch_fused_unpack_chunk_major( - const uint8_t* device_buf, - const int64_t* segment_ptrs, - const int64_t* segment_block_bytes, - const int64_t* segment_prefix_bytes, - const int64_t* chunk_block_counts, - const int64_t* chunk_block_offsets, - const int64_t* chunk_output_bases, - const int64_t* block_ids, - int64_t num_chunks, - int64_t num_segments, - hipStream_t stream) { - const dim3 grid(static_cast(num_chunks * num_segments)); - const dim3 block(kThreads); - hipLaunchKernelGGL( - unpack_chunk_major_kernel, - grid, - block, - 0, - stream, - device_buf, - segment_ptrs, - segment_block_bytes, - segment_prefix_bytes, - chunk_block_counts, - chunk_block_offsets, - chunk_output_bases, - block_ids, - num_segments); - return hipGetLastError(); -} diff --git a/atom/kv_transfer/offload/triton_kv_staging.py b/atom/kv_transfer/offload/triton_kv_staging.py new file mode 100644 index 0000000000..384a56c182 --- /dev/null +++ b/atom/kv_transfer/offload/triton_kv_staging.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Triton fused chunk-major staging for ATOM LMCache offload.""" + +from __future__ import annotations + +import torch +import triton +import triton.language as tl + + +_BLOCK_BYTES = 1024 + + +@triton.jit +def _pack_chunk_major_kernel( + device_buf, + segment_ptrs, + segment_block_bytes, + segment_prefix_bytes, + chunk_block_counts, + chunk_block_offsets, + chunk_output_bases, + block_ids, + NUM_SEGMENTS: tl.constexpr, + BLOCK_BYTES: tl.constexpr, +): + job = tl.program_id(0) + tile = tl.program_id(1) + chunk_id = job // NUM_SEGMENTS + seg_id = job - chunk_id * NUM_SEGMENTS + + nblocks = tl.load(chunk_block_counts + chunk_id).to(tl.int64) + seg_bytes = tl.load(segment_block_bytes + seg_id).to(tl.int64) + nbytes = nblocks * seg_bytes + offsets = tile.to(tl.int64) * BLOCK_BYTES + tl.arange(0, BLOCK_BYTES).to(tl.int64) + mask = offsets < nbytes + + local_block = offsets // seg_bytes + byte_in_block = offsets - local_block * seg_bytes + block_offset = tl.load(chunk_block_offsets + chunk_id).to(tl.int64) + physical_block = tl.load( + block_ids + block_offset + local_block, + mask=mask, + other=0, + ).to(tl.int64) + + seg_addr = tl.load(segment_ptrs + seg_id) + src = (seg_addr + physical_block * seg_bytes + byte_in_block).to( + tl.pointer_type(tl.uint8) + ) + dst = ( + device_buf + + tl.load(chunk_output_bases + chunk_id).to(tl.int64) + + tl.load(segment_prefix_bytes + seg_id).to(tl.int64) * nblocks + + offsets + ) + data = tl.load(src, mask=mask) + tl.store(dst, data, mask=mask) + + +@triton.jit +def _unpack_chunk_major_kernel( + device_buf, + segment_ptrs, + segment_block_bytes, + segment_prefix_bytes, + chunk_block_counts, + chunk_block_offsets, + chunk_output_bases, + block_ids, + NUM_SEGMENTS: tl.constexpr, + BLOCK_BYTES: tl.constexpr, +): + job = tl.program_id(0) + tile = tl.program_id(1) + chunk_id = job // NUM_SEGMENTS + seg_id = job - chunk_id * NUM_SEGMENTS + + nblocks = tl.load(chunk_block_counts + chunk_id).to(tl.int64) + seg_bytes = tl.load(segment_block_bytes + seg_id).to(tl.int64) + nbytes = nblocks * seg_bytes + offsets = tile.to(tl.int64) * BLOCK_BYTES + tl.arange(0, BLOCK_BYTES).to(tl.int64) + mask = offsets < nbytes + + local_block = offsets // seg_bytes + byte_in_block = offsets - local_block * seg_bytes + block_offset = tl.load(chunk_block_offsets + chunk_id).to(tl.int64) + physical_block = tl.load( + block_ids + block_offset + local_block, + mask=mask, + other=0, + ).to(tl.int64) + + src = ( + device_buf + + tl.load(chunk_output_bases + chunk_id).to(tl.int64) + + tl.load(segment_prefix_bytes + seg_id).to(tl.int64) * nblocks + + offsets + ) + seg_addr = tl.load(segment_ptrs + seg_id) + dst = (seg_addr + physical_block * seg_bytes + byte_in_block).to( + tl.pointer_type(tl.uint8) + ) + data = tl.load(src, mask=mask) + tl.store(dst, data, mask=mask) + + +def load_extension() -> None: + """Compatibility hook matching the old native module API.""" + return None + + +def _device_i64(values: list[int], device: torch.device) -> torch.Tensor: + return torch.tensor(values, dtype=torch.int64, device=device) + + +def _build_meta( + segment_tensors, + segment_block_bytes, + chunk_block_counts, + block_ids, + device_buf: torch.Tensor, +) -> tuple[torch.Tensor, ...]: + if not device_buf.is_cuda: + raise ValueError("device_buf must be a CUDA/HIP tensor") + if device_buf.dtype != torch.uint8: + raise TypeError("device_buf must be uint8") + if not device_buf.is_contiguous(): + raise ValueError("device_buf must be contiguous") + if len(segment_tensors) != len(segment_block_bytes): + raise ValueError("segment_tensors and segment_block_bytes size mismatch") + if not segment_tensors: + raise ValueError("at least one segment is required") + + device = device_buf.device + segment_ptr_values: list[int] = [] + segment_prefix_values: list[int] = [] + bytes_per_block = 0 + for seg, nb in zip(segment_tensors, segment_block_bytes, strict=True): + if not seg.is_cuda: + raise ValueError("segment tensor must be CUDA/HIP") + if seg.device != device: + raise ValueError("segment/device mismatch") + if not seg.is_contiguous(): + raise ValueError("segment tensor must be contiguous") + nb = int(nb) + if nb <= 0: + raise ValueError("segment block bytes must be > 0") + segment_ptr_values.append(int(seg.data_ptr())) + segment_prefix_values.append(bytes_per_block) + bytes_per_block += nb + + chunk_block_offsets: list[int] = [] + chunk_output_bases: list[int] = [] + block_offset = 0 + byte_offset = 0 + max_tile_nbytes = 0 + max_seg_bytes = max(int(nb) for nb in segment_block_bytes) + for nblocks in chunk_block_counts: + nblocks = int(nblocks) + if nblocks < 0: + raise ValueError("chunk block count must be non-negative") + chunk_block_offsets.append(block_offset) + chunk_output_bases.append(byte_offset) + block_offset += nblocks + byte_offset += nblocks * bytes_per_block + max_tile_nbytes = max(max_tile_nbytes, nblocks * max_seg_bytes) + + if len(block_ids) != block_offset: + raise ValueError("block_ids length does not match chunk block counts") + if int(device_buf.numel()) < byte_offset: + raise ValueError("device_buf is smaller than chunk-major staging output") + + return ( + _device_i64(segment_ptr_values, device), + _device_i64([int(x) for x in segment_block_bytes], device), + _device_i64(segment_prefix_values, device), + _device_i64([int(x) for x in chunk_block_counts], device), + _device_i64(chunk_block_offsets, device), + _device_i64(chunk_output_bases, device), + _device_i64([int(x) for x in block_ids], device), + torch.tensor([int(byte_offset), int(max_tile_nbytes)], dtype=torch.int64), + ) + + +def fused_pack_chunk_major( + segment_tensors, + segment_block_bytes, + chunk_block_counts, + block_ids, + device_buf, +) -> None: + ( + segment_ptrs, + segment_block_bytes_t, + segment_prefix_bytes, + chunk_block_counts_t, + chunk_block_offsets, + chunk_output_bases, + block_ids_t, + sizes, + ) = _build_meta( + segment_tensors, + segment_block_bytes, + chunk_block_counts, + block_ids, + device_buf, + ) + if int(sizes[0].item()) == 0: + return + grid = ( + len(chunk_block_counts) * len(segment_tensors), + triton.cdiv(int(sizes[1].item()), _BLOCK_BYTES), + ) + _pack_chunk_major_kernel[grid]( + device_buf, + segment_ptrs, + segment_block_bytes_t, + segment_prefix_bytes, + chunk_block_counts_t, + chunk_block_offsets, + chunk_output_bases, + block_ids_t, + NUM_SEGMENTS=len(segment_tensors), + BLOCK_BYTES=_BLOCK_BYTES, + num_warps=8, + ) + + +def fused_unpack_chunk_major( + device_buf, + segment_tensors, + segment_block_bytes, + chunk_block_counts, + block_ids, +) -> None: + ( + segment_ptrs, + segment_block_bytes_t, + segment_prefix_bytes, + chunk_block_counts_t, + chunk_block_offsets, + chunk_output_bases, + block_ids_t, + sizes, + ) = _build_meta( + segment_tensors, + segment_block_bytes, + chunk_block_counts, + block_ids, + device_buf, + ) + if int(sizes[0].item()) == 0: + return + grid = ( + len(chunk_block_counts) * len(segment_tensors), + triton.cdiv(int(sizes[1].item()), _BLOCK_BYTES), + ) + _unpack_chunk_major_kernel[grid]( + device_buf, + segment_ptrs, + segment_block_bytes_t, + segment_prefix_bytes, + chunk_block_counts_t, + chunk_block_offsets, + chunk_output_bases, + block_ids_t, + NUM_SEGMENTS=len(segment_tensors), + BLOCK_BYTES=_BLOCK_BYTES, + num_warps=8, + ) From 00446512bec4805b7ffb155d7fa50b43af4bbfbe Mon Sep 17 00:00:00 2001 From: yihonglie Date: Thu, 4 Jun 2026 00:34:09 -0500 Subject: [PATCH 20/27] Refactor scheduler remote KV admission --- atom/model_engine/scheduler.py | 412 +++++++++++++++++++-------------- 1 file changed, 239 insertions(+), 173 deletions(-) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index d3c790442e..eb21b5cfe2 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -714,105 +714,20 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: self._rejected.append(seq) continue - # KV Transfer: skip request if still waiting for remote KVs - waiting_remote_to_waiting_ready = False - if seq.status == SequenceStatus.WAITING_FOR_REMOTE_KVS: - if self._pop_req_id(self.failed_recving_kv_req_ids, seq.id): - offload_trace( - "scheduler_load_failed_wake", - req=seq.id, - cached=seq.num_cached_tokens, - prompt=seq.num_prompt_tokens, - ) - if self.kv_connector is not None and hasattr( - self.kv_connector, "load_failed" - ): - self.kv_connector.load_failed(seq.id) - seq.status = SequenceStatus.WAITING - seq.offload_loaded = False - seq.offload_loaded_tokens = seq.num_cached_tokens - seq.offload_load_failed = True - else: - waiting_remote_to_waiting_ready = ( - self._update_waiting_for_remote_kv(seq) - ) - if waiting_remote_to_waiting_ready: - seq.status = SequenceStatus.WAITING - else: - skipped_waiting_requests.append(seq) - continue - - # OFFLOAD fresh-wake: a seq whose CPU/NVMe prefix just finished - # loading into its GPU blocks. Unlike P/D (which jumps straight to - # decode with an injected first token), offload must resume as a - # PREFILL of only the un-loaded SUFFIX: bump num_cached_tokens to the - # loaded count and fall through to admission WITHOUT re-matching or - # re-allocating (blocks were allocated for the whole seq before it - # parked). `offload_resume` also covers the case where a woken - # offload seq did not fit this step and was re-parked as plain - # WAITING -- it prevents re-running get_num_new_matched_tokens and, - # critically, re-calling block_manager.allocate() onto an already - # populated block_table (005's `assert not seq.block_table` crash). - is_offload = self._is_offload_connector() - if waiting_remote_to_waiting_ready and is_offload: - loaded = getattr(seq, "offload_loaded_tokens", None) - logger.debug( - "[OFFLOAD-WAKE] seq %s: loaded=%s prev_cached=%d num_tokens=%d", - seq.id, - loaded, - seq.num_cached_tokens, - seq.num_tokens, - ) - offload_trace( - "scheduler_offload_wake", - req=seq.id, - loaded=loaded, - prev_cached=seq.num_cached_tokens, - prompt=seq.num_prompt_tokens, - ) - if loaded is not None and loaded > seq.num_cached_tokens: - seq.num_cached_tokens = loaded - seq.offload_loaded = True - waiting_remote_to_waiting_ready = False # not the P/D decode jump - - offload_resume = ( - is_offload - and ( - getattr(seq, "offload_loaded", False) - or getattr(seq, "offload_load_failed", False) - ) - and len(seq.block_table) > 0 + remote_ready_for_decode = self._resolve_waiting_remote_kv( + seq, skipped_waiting_requests ) + if remote_ready_for_decode is None: + continue - need_to_remove_to_load_kv_async_queue = False - if ( - self.kv_connector is not None - and not waiting_remote_to_waiting_ready - and not offload_resume - ): - _ext_tokens, need_to_remove_to_load_kv_async_queue = ( - self.kv_connector.get_num_new_matched_tokens(seq) - ) + offload_resume = self._is_offload_prefill_resume(seq) + needs_remote_load = self._query_connector_prefill_match( + seq, + skip=remote_ready_for_decode or offload_resume, + ) - if waiting_remote_to_waiting_ready: - seq.status = SequenceStatus.RUNNING - seq.is_first_decode = True - first_token_id = (seq.kv_transfer_params or {}).get("first_token_id") - if first_token_id is not None: - seq.append_token(first_token_id) - seq._injected_t0 = first_token_id - logger.info( - "[PD-TRANSITION] seq %s: num_tokens=%d, " - "num_prompt=%d, blocks=%d, first_token=%s, " - "last_5_tids=%s", - seq.id, - seq.num_tokens, - seq.num_prompt_tokens, - len(seq.block_table), - first_token_id, - seq.token_ids[-5:], - ) - self.running.append(seq) + if remote_ready_for_decode: + self._schedule_first_decode_after_remote_kv(seq) continue if offload_resume: @@ -820,35 +735,22 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: # the batch budget. No re-match / re-allocate / re-park. num_new_tokens = seq.num_prompt_tokens - seq.num_cached_tokens budget_remaining = self.max_num_batched_tokens - num_batched_tokens - if self.enable_chunked_prefill: - chunk = min(num_new_tokens, budget_remaining) - else: - if num_new_tokens > budget_remaining and num_batched_tokens > 0: - self.waiting.appendleft(seq) - break - chunk = num_new_tokens - - assert chunk > 0, ( - f"chunk must be positive: {chunk=}, " - f"{num_new_tokens=}, {budget_remaining=}" + chunk = self._prefill_chunk_for_budget( + num_new_tokens, budget_remaining, num_batched_tokens ) - num_seqs_prefill += 1 - if self.cache_stats: - self.cache_stats.update(seq.num_cached_tokens, seq.num_tokens) - num_batched_tokens += chunk - seq.status = SequenceStatus.RUNNING - seq.type = SequenceType.PREFILL - self.running.append(seq) - scheduled_seqs[seq.id] = seq - num_scheduled_tokens.append(chunk) - offload_trace( - "scheduler_prefill_scheduled", - req=seq.id, - new_tokens=chunk, - cached=seq.num_cached_tokens, - prompt=seq.num_prompt_tokens, - offload_loaded=getattr(seq, "offload_loaded", False), - load_failed=getattr(seq, "offload_load_failed", False), + if chunk is None: + self.waiting.appendleft(seq) + break + self._assert_positive_prefill_chunk( + chunk, num_new_tokens, budget_remaining + ) + num_seqs_prefill, num_batched_tokens = self._schedule_prefill_seq( + seq, + chunk, + scheduled_seqs, + num_scheduled_tokens, + num_seqs_prefill, + num_batched_tokens, ) continue @@ -864,13 +766,12 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: - num_cached_blocks * self.block_manager.block_size ) budget_remaining = self.max_num_batched_tokens - num_batched_tokens - if self.enable_chunked_prefill: - chunk = min(num_new_tokens, budget_remaining) - else: - if num_new_tokens > budget_remaining and num_batched_tokens > 0: - self.waiting.appendleft(seq) - break - chunk = num_new_tokens + chunk = self._prefill_chunk_for_budget( + num_new_tokens, budget_remaining, num_batched_tokens + ) + if chunk is None: + self.waiting.appendleft(seq) + break t_alloc0 = time.perf_counter() self.block_manager.allocate(seq, num_cached_blocks) offload_trace( @@ -882,53 +783,26 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: alloc_ms=f"{(time.perf_counter() - t_alloc0) * 1000:.2f}", ) - if self.kv_connector is not None: - self.kv_connector.update_state_after_alloc(seq) + self._notify_connector_after_prefill_alloc(seq) - if need_to_remove_to_load_kv_async_queue: - if hasattr(self.kv_connector, "should_park_for_load_after_alloc"): - need_to_remove_to_load_kv_async_queue = ( - self.kv_connector.should_park_for_load_after_alloc(seq) - ) + needs_remote_load = self._confirm_remote_load_after_alloc( + seq, needs_remote_load + ) - if need_to_remove_to_load_kv_async_queue: - offload_trace( - "scheduler_park_for_load", - req=seq.id, - cached=seq.num_cached_tokens, - prompt=seq.num_prompt_tokens, - blocks=len(seq.block_table), - ) - skipped_waiting_requests.append(seq) - seq.status = SequenceStatus.WAITING_FOR_REMOTE_KVS + if needs_remote_load: + self._park_for_remote_load(seq, skipped_waiting_requests) continue - if self.kv_connector is not None and hasattr( - self.kv_connector, "adjust_prefill_chunk_after_alloc" - ): - chunk = self.kv_connector.adjust_prefill_chunk_after_alloc(seq, chunk) + chunk = self._adjust_prefill_chunk_after_alloc(seq, chunk) - assert chunk > 0, ( - f"chunk must be positive: {chunk=}, " - f"{num_new_tokens=}, {budget_remaining=}" - ) - num_seqs_prefill += 1 - if self.cache_stats: - self.cache_stats.update(seq.num_cached_tokens, seq.num_tokens) - num_batched_tokens += chunk - seq.status = SequenceStatus.RUNNING - seq.type = SequenceType.PREFILL - self.running.append(seq) - scheduled_seqs[seq.id] = seq - num_scheduled_tokens.append(chunk) - offload_trace( - "scheduler_prefill_scheduled", - req=seq.id, - new_tokens=chunk, - cached=seq.num_cached_tokens, - prompt=seq.num_prompt_tokens, - offload_loaded=getattr(seq, "offload_loaded", False), - load_failed=getattr(seq, "offload_load_failed", False), + self._assert_positive_prefill_chunk(chunk, num_new_tokens, budget_remaining) + num_seqs_prefill, num_batched_tokens = self._schedule_prefill_seq( + seq, + chunk, + scheduled_seqs, + num_scheduled_tokens, + num_seqs_prefill, + num_batched_tokens, ) if skipped_waiting_requests: @@ -1054,6 +928,198 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: ) return (decode_batch, scheduled_seqs) + # -- Remote KV / offload admission helpers ------------------------------ + def _resolve_waiting_remote_kv( + self, seq: Sequence, skipped_waiting_requests: deque[Sequence] + ) -> Optional[bool]: + """Resolve a ``WAITING_FOR_REMOTE_KVS`` request before admission. + + Returns: + - ``None`` when the request is still blocked and has been requeued. + - ``True`` when a P/D consumer should jump to first decode. + - ``False`` when normal prefill admission should continue. + """ + if seq.status != SequenceStatus.WAITING_FOR_REMOTE_KVS: + return False + + if self._consume_failed_remote_kv(seq): + return False + + if not self._update_waiting_for_remote_kv(seq): + skipped_waiting_requests.append(seq) + return None + + seq.status = SequenceStatus.WAITING + if self._is_offload_connector(): + self._mark_offload_load_ready(seq) + return False + return True + + def _consume_failed_remote_kv(self, seq: Sequence) -> bool: + if not self._pop_req_id(self.failed_recving_kv_req_ids, seq.id): + return False + + offload_trace( + "scheduler_load_failed_wake", + req=seq.id, + cached=seq.num_cached_tokens, + prompt=seq.num_prompt_tokens, + ) + if self.kv_connector is not None and hasattr(self.kv_connector, "load_failed"): + self.kv_connector.load_failed(seq.id) + seq.status = SequenceStatus.WAITING + seq.offload_loaded = False + seq.offload_loaded_tokens = seq.num_cached_tokens + seq.offload_load_failed = True + return True + + def _mark_offload_load_ready(self, seq: Sequence) -> None: + """Turn a completed offload load into a suffix-prefill resume.""" + loaded = getattr(seq, "offload_loaded_tokens", None) + logger.debug( + "[OFFLOAD-WAKE] seq %s: loaded=%s prev_cached=%d num_tokens=%d", + seq.id, + loaded, + seq.num_cached_tokens, + seq.num_tokens, + ) + offload_trace( + "scheduler_offload_wake", + req=seq.id, + loaded=loaded, + prev_cached=seq.num_cached_tokens, + prompt=seq.num_prompt_tokens, + ) + if loaded is not None and loaded > seq.num_cached_tokens: + seq.num_cached_tokens = loaded + seq.offload_loaded = True + + def _is_offload_prefill_resume(self, seq: Sequence) -> bool: + """True when offload already owns blocks and should resume suffix prefill. + + This avoids a second prefix lookup and, more importantly, avoids calling + ``BlockManager.allocate`` again for a sequence whose block table was + allocated before it parked for the LMCache load. + """ + return ( + self._is_offload_connector() + and ( + getattr(seq, "offload_loaded", False) + or getattr(seq, "offload_load_failed", False) + ) + and len(seq.block_table) > 0 + ) + + def _query_connector_prefill_match(self, seq: Sequence, *, skip: bool) -> bool: + """Ask the connector whether this prefill should park for remote KV.""" + if skip or self.kv_connector is None: + return False + _ext_tokens, needs_remote_load = self.kv_connector.get_num_new_matched_tokens( + seq + ) + return needs_remote_load + + def _schedule_first_decode_after_remote_kv(self, seq: Sequence) -> None: + """P/D path: a remote prefill completed, so schedule first decode.""" + seq.status = SequenceStatus.RUNNING + seq.is_first_decode = True + first_token_id = (seq.kv_transfer_params or {}).get("first_token_id") + if first_token_id is not None: + seq.append_token(first_token_id) + seq._injected_t0 = first_token_id + logger.info( + "[PD-TRANSITION] seq %s: num_tokens=%d, " + "num_prompt=%d, blocks=%d, first_token=%s, " + "last_5_tids=%s", + seq.id, + seq.num_tokens, + seq.num_prompt_tokens, + len(seq.block_table), + first_token_id, + seq.token_ids[-5:], + ) + self.running.append(seq) + + def _prefill_chunk_for_budget( + self, num_new_tokens: int, budget_remaining: int, num_batched_tokens: int + ) -> Optional[int]: + if self.enable_chunked_prefill: + return min(num_new_tokens, budget_remaining) + if num_new_tokens > budget_remaining and num_batched_tokens > 0: + return None + return num_new_tokens + + @staticmethod + def _assert_positive_prefill_chunk( + chunk: int, num_new_tokens: int, budget_remaining: int + ) -> None: + assert chunk > 0, ( + f"chunk must be positive: {chunk=}, " + f"{num_new_tokens=}, {budget_remaining=}" + ) + + def _schedule_prefill_seq( + self, + seq: Sequence, + chunk: int, + scheduled_seqs: dict[int, Sequence], + num_scheduled_tokens: list[int], + num_seqs_prefill: int, + num_batched_tokens: int, + ) -> tuple[int, int]: + num_seqs_prefill += 1 + if self.cache_stats: + self.cache_stats.update(seq.num_cached_tokens, seq.num_tokens) + num_batched_tokens += chunk + seq.status = SequenceStatus.RUNNING + seq.type = SequenceType.PREFILL + self.running.append(seq) + scheduled_seqs[seq.id] = seq + num_scheduled_tokens.append(chunk) + offload_trace( + "scheduler_prefill_scheduled", + req=seq.id, + new_tokens=chunk, + cached=seq.num_cached_tokens, + prompt=seq.num_prompt_tokens, + offload_loaded=getattr(seq, "offload_loaded", False), + load_failed=getattr(seq, "offload_load_failed", False), + ) + return num_seqs_prefill, num_batched_tokens + + def _notify_connector_after_prefill_alloc(self, seq: Sequence) -> None: + if self.kv_connector is not None: + self.kv_connector.update_state_after_alloc(seq) + + def _confirm_remote_load_after_alloc( + self, seq: Sequence, needs_remote_load: bool + ) -> bool: + if not needs_remote_load: + return False + if hasattr(self.kv_connector, "should_park_for_load_after_alloc"): + return self.kv_connector.should_park_for_load_after_alloc(seq) + return True + + def _park_for_remote_load( + self, seq: Sequence, skipped_waiting_requests: deque[Sequence] + ) -> None: + offload_trace( + "scheduler_park_for_load", + req=seq.id, + cached=seq.num_cached_tokens, + prompt=seq.num_prompt_tokens, + blocks=len(seq.block_table), + ) + skipped_waiting_requests.append(seq) + seq.status = SequenceStatus.WAITING_FOR_REMOTE_KVS + + def _adjust_prefill_chunk_after_alloc(self, seq: Sequence, chunk: int) -> int: + if self.kv_connector is not None and hasattr( + self.kv_connector, "adjust_prefill_chunk_after_alloc" + ): + return self.kv_connector.adjust_prefill_chunk_after_alloc(seq, chunk) + return chunk + def preempt(self, seq: Sequence): seq.status = SequenceStatus.WAITING # Strip placeholder + rejected draft tokens added by postprocess. From 889248baf1e809334a6f1bb3e0a2818e3ecd3b90 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Thu, 4 Jun 2026 00:56:08 -0500 Subject: [PATCH 21/27] Remove offload host fallback staging --- atom/kv_transfer/offload/connector.py | 33 +-- atom/kv_transfer/offload/gpu_connector.py | 4 +- atom/kv_transfer/offload/lmcache_compat.py | 287 +++++--------------- atom/kv_transfer/offload/lmcache_staging.py | 1 - tests/test_lmcache_offload_connector.py | 157 ++--------- 5 files changed, 97 insertions(+), 385 deletions(-) diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index 88869d92cd..a828c39aee 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -244,15 +244,6 @@ def _profile_enabled(self) -> bool: "off", ) - def _last_gpu_connector_fastpath(self) -> str: - gpu_connector = getattr(getattr(self, "_engine", None), "gpu_connector", None) - if gpu_connector is None or not hasattr(gpu_connector, "last_fastpath"): - return "unknown" - try: - return str(gpu_connector.last_fastpath()) - except Exception: - return "unknown" - def _last_gpu_connector_transfer_stats(self) -> dict[str, int | float]: gpu_connector = getattr(getattr(self, "_engine", None), "gpu_connector", None) if gpu_connector is None or not hasattr(gpu_connector, "last_transfer_stats"): @@ -262,12 +253,12 @@ def _last_gpu_connector_transfer_stats(self) -> dict[str, int | float]: except Exception: return {} - def _reset_gpu_connector_fastpath(self) -> None: + def _reset_gpu_connector_transfer_stats(self) -> None: gpu_connector = getattr(getattr(self, "_engine", None), "gpu_connector", None) - if gpu_connector is None or not hasattr(gpu_connector, "reset_fastpath"): + if gpu_connector is None or not hasattr(gpu_connector, "reset_transfer_stats"): return try: - gpu_connector.reset_fastpath() + gpu_connector.reset_transfer_stats() except Exception: pass @@ -327,7 +318,7 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: mask[:hbm] = False t_retrieve0 = time.perf_counter() - self._reset_gpu_connector_fastpath() + self._reset_gpu_connector_transfer_stats() ret_mask = self._engine.retrieve( torch.tensor(toks), mask=mask, @@ -335,7 +326,6 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: req_id=str(req.req_id), ) retrieve_ms = (time.perf_counter() - t_retrieve0) * 1000 - fastpath = self._last_gpu_connector_fastpath() transfer_stats = self._last_gpu_connector_transfer_stats() self._lookup_unpin(req.req_id) loaded = bool(ret_mask[hbm:lmc].all().item()) if lmc > hbm else True @@ -353,7 +343,6 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: hbm=hbm, lmc=lmc, retrieved=int(ret_mask.sum().item()), - fastpath=fastpath, chunks=transfer_stats.get("chunks", 0), groups=transfer_stats.get("groups", 0), max_chunk_bytes=transfer_stats.get("max_chunk_bytes", 0), @@ -375,8 +364,8 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: if self._profile_enabled(): logger.info( "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " - "retrieved=%d status=%s fastpath=%s chunks=%d " - "groups=%d max_chunk_bytes=%d max_group_bytes=%d " + "retrieved=%d status=%s chunks=%d groups=%d " + "max_chunk_bytes=%d max_group_bytes=%d " "gpu_staging_chunk_bytes=%d gpu_staging_group_chunks=%d " "gpu_staging_capacity_bytes=%d total_bytes=%d " "pack_ms=%.2f copy_ms=%.2f sync_ms=%.2f " @@ -388,7 +377,6 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: lmc, int(ret_mask.sum().item()), "ok" if loaded else "miss", - fastpath, int(transfer_stats.get("chunks", 0)), int(transfer_stats.get("groups", 0)), int(transfer_stats.get("max_chunk_bytes", 0)), @@ -439,7 +427,7 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: mask[:skip] = False t_store0 = time.perf_counter() - self._reset_gpu_connector_fastpath() + self._reset_gpu_connector_transfer_stats() self._engine.store( torch.tensor(toks), mask=mask, @@ -447,7 +435,6 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: req_id=str(req.req_id), ) store_ms = (time.perf_counter() - t_store0) * 1000 - fastpath = self._last_gpu_connector_fastpath() transfer_stats = self._last_gpu_connector_transfer_stats() with self._lock: self._done_save.add(req.req_id) @@ -459,7 +446,6 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: status="ok", toks=len(toks), skip=skip, - fastpath=fastpath, chunks=transfer_stats.get("chunks", 0), groups=transfer_stats.get("groups", 0), max_chunk_bytes=transfer_stats.get("max_chunk_bytes", 0), @@ -481,8 +467,8 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: if self._profile_enabled(): logger.info( "[OFFLOAD-SAVE-PROF] rank=%s req=%s toks=%d skip=%d " - "fastpath=%s chunks=%d groups=%d max_chunk_bytes=%d " - "max_group_bytes=%d gpu_staging_chunk_bytes=%d " + "chunks=%d groups=%d max_chunk_bytes=%d max_group_bytes=%d " + "gpu_staging_chunk_bytes=%d " "gpu_staging_group_chunks=%d gpu_staging_capacity_bytes=%d " "total_bytes=%d pack_ms=%.2f copy_ms=%.2f sync_ms=%.2f " "transfer_ms=%.2f effective_gbps=%.2f " @@ -491,7 +477,6 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: req.req_id, len(toks), skip, - fastpath, int(transfer_stats.get("chunks", 0)), int(transfer_stats.get("groups", 0)), int(transfer_stats.get("max_chunk_bytes", 0)), diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py index 78bbf69327..27d8a95f99 100644 --- a/atom/kv_transfer/offload/gpu_connector.py +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -97,7 +97,7 @@ def __init__(self, kv_caches: dict) -> None: except Exception: logger.warning( "ATOMKVByteCodec: Triton KV staging unavailable; " - "using chunk fallback", + "fused chunk-major staging disabled", exc_info=True, ) @@ -239,7 +239,7 @@ def gpu_to_chunk_major_device_buffer( Layout is MemoryObj-compatible: ``[chunk0: seg0 blocks | seg1 blocks | ...][chunk1: ...]``. Fused Triton staging is used when available; otherwise this method - provides a reference implementation for tests and CPU fallback. + provides a reference implementation for tests. """ groups, flat_block_ids, chunk_block_counts = self._normalize_block_id_groups( block_id_groups, diff --git a/atom/kv_transfer/offload/lmcache_compat.py b/atom/kv_transfer/offload/lmcache_compat.py index 311f181e62..2bcbf6b1ff 100644 --- a/atom/kv_transfer/offload/lmcache_compat.py +++ b/atom/kv_transfer/offload/lmcache_compat.py @@ -122,9 +122,6 @@ def gpu_staging_capacity_bytes(self) -> int: def release_gpu_staging_after_transfer(self) -> bool: return self._release_gpu_staging_after_transfer - def _set_last_fastpath(self, value: str) -> None: - self._tls.last_fastpath = value - def _set_last_transfer_stats( self, *, @@ -158,13 +155,9 @@ def _set_last_transfer_stats( "effective_gbps": float(effective_gbps), } - def reset_fastpath(self) -> None: - self._set_last_fastpath("none") + def reset_transfer_stats(self) -> None: self._set_last_transfer_stats() - def last_fastpath(self) -> str: - return getattr(self._tls, "last_fastpath", "unknown") - def last_transfer_stats(self) -> dict[str, int | float]: return dict(getattr(self._tls, "last_transfer_stats", {})) @@ -218,44 +211,18 @@ def _release_slot_if_requested(self, slot: _StagingSlot) -> None: slot.tensor = None slot.free_event_valid = False - def _ensure_host_tmp( - self, - state: _ThreadTransferState, - nbytes: int, - ) -> torch.Tensor: - nbytes = int(nbytes) - if nbytes > self._gpu_staging_capacity_bytes: - raise RuntimeError( - "ATOM LMCache connector internal error: transfer group exceeds " - "bounded host staging capacity: " - f"nbytes={nbytes}, capacity={self._gpu_staging_capacity_bytes}" - ) - if ( - state.host_tmp is None - or int(state.host_tmp.numel()) != self._gpu_staging_capacity_bytes - ): - if state.use_cuda: - try: - state.host_tmp = torch.empty( - (self._gpu_staging_capacity_bytes,), - dtype=torch.uint8, - pin_memory=True, - ) - except RuntimeError: - state.host_tmp = torch.empty( - (self._gpu_staging_capacity_bytes,), - dtype=torch.uint8, - ) - else: - state.host_tmp = torch.empty( - (self._gpu_staging_capacity_bytes,), - dtype=torch.uint8, - ) - return state.host_tmp[:nbytes] - def _can_use_fused_chunk_major(self) -> bool: return self._use_cuda() and self.codec.has_fused_chunk_major_staging + def _require_fused_chunk_major(self) -> None: + if self._can_use_fused_chunk_major(): + return + raise RuntimeError( + "ATOM LMCache connector requires Triton fused chunk-major staging; " + "set OFFLOAD_FUSED_KV_STAGING=1 and ensure the Triton staging " + "kernel loads successfully" + ) + def _memory_tensor(self, memory_obj: Any, nbytes: int) -> torch.Tensor: tensor = getattr(memory_obj, "tensor", None) if tensor is None and hasattr(memory_obj, "get_tensor"): @@ -456,113 +423,50 @@ def batched_from_gpu( return state = self._thread_state() - use_cuda = state.use_cuda chunks = self._iter_transfer_chunks(memory_objs, block_id_groups) groups = self._iter_transfer_groups(chunks) self._record_transfer_stats(chunks, groups) if not chunks: return - if self._can_use_fused_chunk_major(): - self._set_last_fastpath("fused_chunk") - used_slots: list[_StagingSlot] = [] - pack_ms = 0.0 - copy_ms = 0.0 - sync_ms = 0.0 - t_total0 = time.perf_counter() - try: - for group in groups: - slot = self._next_slot(state) - self._remember_slot(used_slots, slot) - device_buf = self._ensure_slot(slot, group.nbytes) - if slot.free_event_valid: - state.pack_stream.wait_event(slot.free_event) - t0 = time.perf_counter() - with state.stream_ctx(state.pack_stream): - self.codec.gpu_to_chunk_major_device_buffer( - device_buf, - self._group_block_ids(group), - stream=state.pack_stream, - ) - pack_ms += (time.perf_counter() - t0) * 1000 - slot.ready_event.record(state.pack_stream) - state.copy_stream.wait_event(slot.ready_event) - t0 = time.perf_counter() - with state.stream_ctx(state.copy_stream): - self._slice_to_memory_objs(group, device_buf) - copy_ms += (time.perf_counter() - t0) * 1000 - slot.free_event.record(state.copy_stream) - slot.free_event_valid = True - t0 = time.perf_counter() - state.copy_stream.synchronize() - sync_ms += (time.perf_counter() - t0) * 1000 - except Exception: - for slot in used_slots: - slot.free_event_valid = False - raise - finally: - self._release_slots_if_requested(used_slots) - self._record_transfer_stats( - chunks, - groups, - pack_ms=pack_ms, - copy_ms=copy_ms, - sync_ms=sync_ms, - transfer_ms=(time.perf_counter() - t_total0) * 1000, - ) - return - - self._set_last_fastpath("chunk") + self._require_fused_chunk_major() + used_slots: list[_StagingSlot] = [] pack_ms = 0.0 copy_ms = 0.0 sync_ms = 0.0 t_total0 = time.perf_counter() - used_slots: list[_StagingSlot] = [] - for group in groups: - slot = self._next_slot(state) - self._remember_slot(used_slots, slot) - device_buf = self._ensure_slot(slot, group.nbytes) - host_buf = self._ensure_host_tmp(state, group.nbytes) - try: - if use_cuda: - if slot.free_event_valid: - state.pack_stream.wait_event(slot.free_event) - t0 = time.perf_counter() - with state.stream_ctx(state.pack_stream): - self.codec.gpu_to_chunk_major_device_buffer( - device_buf, - self._group_block_ids(group), - stream=state.pack_stream, - ) - pack_ms += (time.perf_counter() - t0) * 1000 - slot.ready_event.record(state.pack_stream) - state.copy_stream.wait_event(slot.ready_event) - t0 = time.perf_counter() - with state.stream_ctx(state.copy_stream): - host_buf.copy_(device_buf, non_blocking=True) - copy_ms += (time.perf_counter() - t0) * 1000 - slot.free_event.record(state.copy_stream) - slot.free_event_valid = True - t0 = time.perf_counter() - state.copy_stream.synchronize() - sync_ms += (time.perf_counter() - t0) * 1000 - else: - t0 = time.perf_counter() + try: + for group in groups: + slot = self._next_slot(state) + self._remember_slot(used_slots, slot) + device_buf = self._ensure_slot(slot, group.nbytes) + if slot.free_event_valid: + state.pack_stream.wait_event(slot.free_event) + t0 = time.perf_counter() + with state.stream_ctx(state.pack_stream): self.codec.gpu_to_chunk_major_device_buffer( device_buf, self._group_block_ids(group), + stream=state.pack_stream, ) - pack_ms += (time.perf_counter() - t0) * 1000 - t0 = time.perf_counter() - host_buf.copy_(device_buf, non_blocking=False) - copy_ms += (time.perf_counter() - t0) * 1000 + pack_ms += (time.perf_counter() - t0) * 1000 + slot.ready_event.record(state.pack_stream) + state.copy_stream.wait_event(slot.ready_event) t0 = time.perf_counter() - self._slice_to_memory_objs(group, host_buf) + with state.stream_ctx(state.copy_stream): + self._slice_to_memory_objs(group, device_buf) copy_ms += (time.perf_counter() - t0) * 1000 - except Exception: + slot.free_event.record(state.copy_stream) + slot.free_event_valid = True + t0 = time.perf_counter() + state.copy_stream.synchronize() + sync_ms += (time.perf_counter() - t0) * 1000 + except Exception: + for slot in used_slots: slot.free_event_valid = False - raise - self._release_slots_if_requested(used_slots) + raise + finally: + self._release_slots_if_requested(used_slots) self._record_transfer_stats( chunks, groups, @@ -590,113 +494,50 @@ def batched_to_gpu( return state = self._thread_state() - use_cuda = state.use_cuda chunks = self._iter_transfer_chunks(memory_objs, block_id_groups) groups = self._iter_transfer_groups(chunks) self._record_transfer_stats(chunks, groups) if not chunks: return - if self._can_use_fused_chunk_major(): - self._set_last_fastpath("fused_chunk") - used_slots: list[_StagingSlot] = [] - copy_ms = 0.0 - pack_ms = 0.0 - sync_ms = 0.0 - t_total0 = time.perf_counter() - try: - for group in groups: - slot = self._next_slot(state) - self._remember_slot(used_slots, slot) - device_buf = self._ensure_slot(slot, group.nbytes) - if slot.free_event_valid: - state.copy_stream.wait_event(slot.free_event) - t0 = time.perf_counter() - with state.stream_ctx(state.copy_stream): - self._memory_objs_to_slice(group, device_buf) - copy_ms += (time.perf_counter() - t0) * 1000 - slot.ready_event.record(state.copy_stream) - state.pack_stream.wait_event(slot.ready_event) - t0 = time.perf_counter() - with state.stream_ctx(state.pack_stream): - self.codec.chunk_major_device_buffer_to_gpu( - device_buf, - self._group_block_ids(group), - stream=state.pack_stream, - ) - pack_ms += (time.perf_counter() - t0) * 1000 - slot.free_event.record(state.pack_stream) - slot.free_event_valid = True - t0 = time.perf_counter() - state.pack_stream.synchronize() - sync_ms += (time.perf_counter() - t0) * 1000 - except Exception: - for slot in used_slots: - slot.free_event_valid = False - raise - finally: - self._release_slots_if_requested(used_slots) - self._record_transfer_stats( - chunks, - groups, - pack_ms=pack_ms, - copy_ms=copy_ms, - sync_ms=sync_ms, - transfer_ms=(time.perf_counter() - t_total0) * 1000, - ) - return - - self._set_last_fastpath("chunk") + self._require_fused_chunk_major() + used_slots: list[_StagingSlot] = [] copy_ms = 0.0 pack_ms = 0.0 sync_ms = 0.0 t_total0 = time.perf_counter() - used_slots: list[_StagingSlot] = [] - for group in groups: - slot = self._next_slot(state) - self._remember_slot(used_slots, slot) - device_buf = self._ensure_slot(slot, group.nbytes) - host_buf = self._ensure_host_tmp(state, group.nbytes) - try: + try: + for group in groups: + slot = self._next_slot(state) + self._remember_slot(used_slots, slot) + device_buf = self._ensure_slot(slot, group.nbytes) + if slot.free_event_valid: + state.copy_stream.wait_event(slot.free_event) t0 = time.perf_counter() - self._memory_objs_to_slice(group, host_buf) + with state.stream_ctx(state.copy_stream): + self._memory_objs_to_slice(group, device_buf) copy_ms += (time.perf_counter() - t0) * 1000 - if use_cuda: - if slot.free_event_valid: - state.copy_stream.wait_event(slot.free_event) - t0 = time.perf_counter() - with state.stream_ctx(state.copy_stream): - device_buf.copy_(host_buf, non_blocking=True) - copy_ms += (time.perf_counter() - t0) * 1000 - slot.ready_event.record(state.copy_stream) - state.pack_stream.wait_event(slot.ready_event) - t0 = time.perf_counter() - with state.stream_ctx(state.pack_stream): - self.codec.chunk_major_device_buffer_to_gpu( - device_buf, - self._group_block_ids(group), - stream=state.pack_stream, - ) - pack_ms += (time.perf_counter() - t0) * 1000 - slot.free_event.record(state.pack_stream) - slot.free_event_valid = True - t0 = time.perf_counter() - state.pack_stream.synchronize() - sync_ms += (time.perf_counter() - t0) * 1000 - else: - t0 = time.perf_counter() - device_buf.copy_(host_buf, non_blocking=False) - copy_ms += (time.perf_counter() - t0) * 1000 - t0 = time.perf_counter() + slot.ready_event.record(state.copy_stream) + state.pack_stream.wait_event(slot.ready_event) + t0 = time.perf_counter() + with state.stream_ctx(state.pack_stream): self.codec.chunk_major_device_buffer_to_gpu( device_buf, self._group_block_ids(group), + stream=state.pack_stream, ) - pack_ms += (time.perf_counter() - t0) * 1000 - except Exception: + pack_ms += (time.perf_counter() - t0) * 1000 + slot.free_event.record(state.pack_stream) + slot.free_event_valid = True + t0 = time.perf_counter() + state.pack_stream.synchronize() + sync_ms += (time.perf_counter() - t0) * 1000 + except Exception: + for slot in used_slots: slot.free_event_valid = False - raise - self._release_slots_if_requested(used_slots) + raise + finally: + self._release_slots_if_requested(used_slots) self._record_transfer_stats( chunks, groups, diff --git a/atom/kv_transfer/offload/lmcache_staging.py b/atom/kv_transfer/offload/lmcache_staging.py index 4404503cda..e3bb497eb8 100644 --- a/atom/kv_transfer/offload/lmcache_staging.py +++ b/atom/kv_transfer/offload/lmcache_staging.py @@ -71,7 +71,6 @@ def __init__( self.pack_stream = None self.copy_stream = None self.next_slot = 0 - self.host_tmp: torch.Tensor | None = None if use_cuda: with torch.cuda.device(device): self.pack_stream = torch.cuda.Stream() diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index 268218efd8..d45aa27be4 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -153,7 +153,7 @@ def test_lmcache_connector_maps_token_ranges_to_block_ids(): if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - original = { + kv_caches = { "l0": SimpleNamespace( k_cache=torch.arange(6 * 2, dtype=torch.uint8).reshape(6, 2), v_cache=(torch.arange(6 * 3, dtype=torch.uint8).reshape(6, 3) + 51), @@ -161,41 +161,25 @@ def test_lmcache_connector_maps_token_ranges_to_block_ids(): v_scale=None, ) } - kv_caches = { - "l0": SimpleNamespace( - k_cache=original["l0"].k_cache.clone(), - v_cache=original["l0"].v_cache.clone(), - k_scale=None, - v_scale=None, - ) - } codec = ATOMKVByteCodec(kv_caches) connector = ATOMLMCacheGPUConnector(codec, block_size=4, chunk_size=8) - memory_obj = SimpleNamespace( - tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) - ) - connector.batched_from_gpu( - [memory_obj], + assert connector._ranges_to_block_ids( [4], [12], block_ids=[0, 1, 2, 3, 4, 5], - ) - - kv_caches["l0"].k_cache.zero_() - kv_caches["l0"].v_cache.zero_() - connector.batched_to_gpu( - [memory_obj], - [4], - [12], + ) == [[1, 2]] + assert connector._ranges_to_block_ids( + [0, 8], + [8, 16], block_ids=[0, 1, 2, 3, 4, 5], - ) - - for bid in [1, 2]: - assert torch.equal(kv_caches["l0"].k_cache[bid], original["l0"].k_cache[bid]) - assert torch.equal(kv_caches["l0"].v_cache[bid], original["l0"].v_cache[bid]) - assert torch.count_nonzero(kv_caches["l0"].k_cache[0]) == 0 - assert torch.count_nonzero(kv_caches["l0"].v_cache[0]) == 0 + ) == [[0, 1], [2, 3]] + with pytest.raises(ValueError, match="block-aligned"): + connector._ranges_to_block_ids( + [2], + [8], + block_ids=[0, 1, 2, 3, 4, 5], + ) def test_lmcache_connector_fused_chunk_fastpath_uses_chunk_major(monkeypatch): @@ -323,7 +307,6 @@ def stream_ctx(self, stream): original["l0"].v_cache[[3]].reshape(-1), ] ) - assert connector.last_fastpath() == "fused_chunk" transfer_stats = connector.last_transfer_stats() assert transfer_stats["chunks"] == 2 assert transfer_stats["groups"] == 1 @@ -350,7 +333,6 @@ def stream_ctx(self, stream): block_ids=[0, 1, 2, 3, 4, 5], ) - assert connector.last_fastpath() == "fused_chunk" assert unpack_groups == [[[1, 2], [3]]] for bid in [1, 2, 3]: assert torch.equal(kv_caches["l0"].k_cache[bid], original["l0"].k_cache[bid]) @@ -359,100 +341,12 @@ def stream_ctx(self, stream): assert torch.count_nonzero(kv_caches["l0"].v_cache[0]) == 0 -def test_lmcache_connector_fallback_staging_is_chunk_bounded(monkeypatch): +def test_lmcache_connector_requires_fused_chunk_major_staging(): import torch if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_GPU_STAGING_CHUNKS", "1") - original = { - "l0": SimpleNamespace( - k_cache=torch.arange(8 * 2, dtype=torch.uint8).reshape(8, 2), - v_cache=(torch.arange(8 * 3, dtype=torch.uint8).reshape(8, 3) + 51), - k_scale=None, - v_scale=None, - ) - } - kv_caches = { - "l0": SimpleNamespace( - k_cache=original["l0"].k_cache.clone(), - v_cache=original["l0"].v_cache.clone(), - k_scale=None, - v_scale=None, - ) - } - codec = ATOMKVByteCodec(kv_caches) - connector = ATOMLMCacheGPUConnector(codec, block_size=4, chunk_size=8) - cap = 2 * codec.bytes_per_block - slot_requests = [] - host_requests = [] - orig_ensure_slot = connector._ensure_slot - orig_ensure_host_tmp = connector._ensure_host_tmp - - def _ensure_slot(slot, nbytes): - device_buf = orig_ensure_slot(slot, nbytes) - slot_requests.append((nbytes, int(slot.tensor.numel()))) - return device_buf - - def _ensure_host_tmp(state, nbytes): - host_buf = orig_ensure_host_tmp(state, nbytes) - host_requests.append((nbytes, int(state.host_tmp.numel()))) - return host_buf - - monkeypatch.setattr(connector, "_ensure_slot", _ensure_slot) - monkeypatch.setattr(connector, "_ensure_host_tmp", _ensure_host_tmp) - memory_objs = [ - SimpleNamespace( - tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) - ), - SimpleNamespace( - tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) - ), - SimpleNamespace( - tensor=torch.empty(1 * codec.bytes_per_block, dtype=torch.uint8) - ), - ] - - connector.batched_from_gpu( - memory_objs, - [0, 8, 16], - [8, 16, 20], - block_ids=list(range(8)), - ) - - assert connector.last_fastpath() == "chunk" - assert connector.last_transfer_stats()["chunks"] == 3 - assert connector.last_transfer_stats()["max_chunk_bytes"] == cap - assert all(nbytes <= cap for nbytes, _ in slot_requests) - assert all(capacity == cap for _, capacity in slot_requests) - assert all(nbytes <= cap for nbytes, _ in host_requests) - assert all(capacity == cap for _, capacity in host_requests) - - kv_caches["l0"].k_cache.zero_() - kv_caches["l0"].v_cache.zero_() - connector.batched_to_gpu( - memory_objs, - [0, 8, 16], - [8, 16, 20], - block_ids=list(range(8)), - ) - - for bid in range(5): - assert torch.equal(kv_caches["l0"].k_cache[bid], original["l0"].k_cache[bid]) - assert torch.equal(kv_caches["l0"].v_cache[bid], original["l0"].v_cache[bid]) - assert torch.count_nonzero(kv_caches["l0"].k_cache[5]) == 0 - assert torch.count_nonzero(kv_caches["l0"].v_cache[5]) == 0 - - -def test_lmcache_connector_release_covers_fallback_chunks(monkeypatch): - import torch - - if not hasattr(torch, "arange"): - pytest.skip("real torch is unavailable") - - monkeypatch.setenv("OFFLOAD_GPU_STAGING_CHUNKS", "1") - monkeypatch.setenv("OFFLOAD_RELEASE_GPU_STAGING_AFTER_TRANSFER", "1") kv_caches = { "l0": SimpleNamespace( k_cache=torch.arange(4 * 2, dtype=torch.uint8).reshape(4, 2), @@ -466,23 +360,16 @@ def test_lmcache_connector_release_covers_fallback_chunks(monkeypatch): memory_objs = [ SimpleNamespace( tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) - ), - SimpleNamespace( - tensor=torch.empty(2 * codec.bytes_per_block, dtype=torch.uint8) - ), + ) ] - connector.batched_from_gpu( - memory_objs, - [0, 8], - [8, 16], - block_ids=list(range(4)), - ) - - state = connector._thread_state() - assert state.host_tmp is not None - assert int(state.host_tmp.numel()) == 2 * codec.bytes_per_block - assert all(slot.tensor is None for slot in state.slots) + with pytest.raises(RuntimeError, match="requires Triton fused"): + connector.batched_from_gpu( + memory_objs, + [0], + [8], + block_ids=list(range(4)), + ) def test_lmcache_connector_rejects_oversized_memory_obj(): From 53eb02af51dc9120c04bed44c0ddcb5609601264 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Thu, 4 Jun 2026 01:24:14 -0500 Subject: [PATCH 22/27] Rename ATOM LMCache offload modules --- .../offload/{lmcache_compat.py => atom_lmcache.py} | 4 ++-- .../offload/{lmcache_staging.py => atom_lmcache_staging.py} | 0 atom/kv_transfer/offload/connector.py | 4 ++-- atom/kv_transfer/offload/gpu_connector.py | 2 +- tests/test_lmcache_offload_connector.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) rename atom/kv_transfer/offload/{lmcache_compat.py => atom_lmcache.py} (99%) rename atom/kv_transfer/offload/{lmcache_staging.py => atom_lmcache_staging.py} (100%) diff --git a/atom/kv_transfer/offload/lmcache_compat.py b/atom/kv_transfer/offload/atom_lmcache.py similarity index 99% rename from atom/kv_transfer/offload/lmcache_compat.py rename to atom/kv_transfer/offload/atom_lmcache.py index 2bcbf6b1ff..c2fded8837 100644 --- a/atom/kv_transfer/offload/lmcache_compat.py +++ b/atom/kv_transfer/offload/atom_lmcache.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -"""LMCache-compatible raw-byte connector for ATOM offload. +"""ATOM LMCache raw-byte connector for offload. This module lets ATOM use LMCache ``CacheEngine.store()`` / ``CacheEngine.retrieve()`` without adopting LMCache's vLLM token-major KV @@ -20,7 +20,7 @@ import torch from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec -from atom.kv_transfer.offload.lmcache_staging import ( +from atom.kv_transfer.offload.atom_lmcache_staging import ( _StagingSlot, _ThreadTransferState, _env_flag, diff --git a/atom/kv_transfer/offload/lmcache_staging.py b/atom/kv_transfer/offload/atom_lmcache_staging.py similarity index 100% rename from atom/kv_transfer/offload/lmcache_staging.py rename to atom/kv_transfer/offload/atom_lmcache_staging.py diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index a828c39aee..f54863f7b7 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -39,7 +39,7 @@ from atom.kv_transfer.disaggregation.types import KVConnectorOutput, ReqId from atom.kv_transfer.offload import config as offcfg from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec -from atom.kv_transfer.offload.lmcache_compat import ATOMLMCacheGPUConnector +from atom.kv_transfer.offload.atom_lmcache import ATOMLMCacheGPUConnector from atom.kv_transfer.offload.metadata import ( ATOMRawBytesLMCacheMetadata, LMCacheOffloadMetadata, @@ -76,7 +76,7 @@ def __init__(self, config) -> None: # parked seq is waiting for it) never queues behind a backlog of fire- # and-forget saves (Phase 4 root cause: with one shared serial daemon, a # reload sat behind ~N filler saves -> request hung well past timeout). - # The LMCache-compatible GPU connector owns per-thread staging streams. + # The ATOM LMCache GPU connector owns per-thread staging streams. # OFFLOAD_COPY_WORKERS tunes the SAVE pool only. n_save_workers = int(os.environ.get("OFFLOAD_COPY_WORKERS", "1")) self._load_executor = ThreadPoolExecutor( diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py index 27d8a95f99..bf6dc96fd3 100644 --- a/atom/kv_transfer/offload/gpu_connector.py +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -186,7 +186,7 @@ def gpu_to_device_buffer( The staging layout is always segment-major: ``[seg0 blocks | seg1 blocks | ...]``. This is the layout consumed by - the LMCache-compatible connector before it copies the bytes to a + the ATOM LMCache connector before it copies the bytes to a ``MemoryObj``. """ block_ids = self._normalize_block_ids(block_ids) diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index d45aa27be4..7b677005d6 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -21,7 +21,7 @@ LMCacheOffloadConnectorScheduler, ) from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec -from atom.kv_transfer.offload.lmcache_compat import ATOMLMCacheGPUConnector +from atom.kv_transfer.offload.atom_lmcache import ATOMLMCacheGPUConnector from atom.kv_transfer.offload.metadata import ATOMRawBytesLMCacheMetadata from atom.model_engine.scheduler import Scheduler From 47f9ca6646ad2f83bee97876faf89b8b69af13d7 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Thu, 4 Jun 2026 02:18:50 -0500 Subject: [PATCH 23/27] Remove obsolete offload staging switches --- atom/kv_transfer/offload/atom_lmcache.py | 4 ++-- atom/kv_transfer/offload/atom_lmcache_staging.py | 1 - atom/kv_transfer/offload/connector.py | 3 +-- atom/kv_transfer/offload/gpu_connector.py | 15 ++------------- atom/kv_transfer/offload/triton_kv_staging.py | 5 ----- tests/test_lmcache_offload_connector.py | 1 - 6 files changed, 5 insertions(+), 24 deletions(-) diff --git a/atom/kv_transfer/offload/atom_lmcache.py b/atom/kv_transfer/offload/atom_lmcache.py index c2fded8837..444ad29604 100644 --- a/atom/kv_transfer/offload/atom_lmcache.py +++ b/atom/kv_transfer/offload/atom_lmcache.py @@ -219,8 +219,8 @@ def _require_fused_chunk_major(self) -> None: return raise RuntimeError( "ATOM LMCache connector requires Triton fused chunk-major staging; " - "set OFFLOAD_FUSED_KV_STAGING=1 and ensure the Triton staging " - "kernel loads successfully" + "ensure KV tensors are on CUDA/HIP and the Triton staging kernel " + "loads successfully" ) def _memory_tensor(self, memory_obj: Any, nbytes: int) -> torch.Tensor: diff --git a/atom/kv_transfer/offload/atom_lmcache_staging.py b/atom/kv_transfer/offload/atom_lmcache_staging.py index e3bb497eb8..af29ff12fe 100644 --- a/atom/kv_transfer/offload/atom_lmcache_staging.py +++ b/atom/kv_transfer/offload/atom_lmcache_staging.py @@ -67,7 +67,6 @@ def __init__( staging_slots: int, ) -> None: self.device = device - self.use_cuda = use_cuda self.pack_stream = None self.copy_stream = None self.next_slot = 0 diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index f54863f7b7..3e25f058f4 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -168,14 +168,13 @@ def _logged_lookup(*a, **k): logger.info( "LMCache offload worker rank=%d: bytes_per_block=%d chunk=%d " - "codec_layout=%s gpu_staging_slots=%d " + "gpu_staging_slots=%d " "gpu_staging_chunk_bytes=%d gpu_staging_group_chunks=%d " "gpu_staging_capacity_bytes=%d release_gpu_staging=%s " "save=%s load=%s", rank, self._codec.bytes_per_block, self.chunk_size, - "chunk_major", gpu_connector.staging_slots, gpu_connector.gpu_staging_chunk_bytes, gpu_connector.gpu_staging_group_chunks, diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/gpu_connector.py index bf6dc96fd3..3846350e59 100644 --- a/atom/kv_transfer/offload/gpu_connector.py +++ b/atom/kv_transfer/offload/gpu_connector.py @@ -30,7 +30,6 @@ import logging import operator -import os import torch @@ -38,7 +37,7 @@ class ATOMKVByteCodec: - """Per-block byte mover between paged GPU KV tensors and a flat host buffer.""" + """Per-block byte mover between paged GPU KV tensors and flat buffers.""" def __init__(self, kv_caches: dict) -> None: """``kv_caches``: ordered ``{layer_name: KVCacheTensor}`` from @@ -79,20 +78,10 @@ def __init__(self, kv_caches: dict) -> None: ] self.bytes_per_block: int = sum(self._seg_block_bytes) self._fused_kv_staging = None - fused_env = os.environ.get( - "OFFLOAD_FUSED_KV_STAGING", - os.environ.get("OFFLOAD_NATIVE_KV_STAGING", "0"), - ) - if self._device.type == "cuda" and fused_env.lower() not in ( - "0", - "false", - "no", - "off", - ): + if self._device.type == "cuda": try: from atom.kv_transfer.offload import triton_kv_staging - triton_kv_staging.load_extension() self._fused_kv_staging = triton_kv_staging except Exception: logger.warning( diff --git a/atom/kv_transfer/offload/triton_kv_staging.py b/atom/kv_transfer/offload/triton_kv_staging.py index 384a56c182..9c48ce60b2 100644 --- a/atom/kv_transfer/offload/triton_kv_staging.py +++ b/atom/kv_transfer/offload/triton_kv_staging.py @@ -107,11 +107,6 @@ def _unpack_chunk_major_kernel( tl.store(dst, data, mask=mask) -def load_extension() -> None: - """Compatibility hook matching the old native module API.""" - return None - - def _device_i64(values: list[int], device: torch.device) -> torch.Tensor: return torch.tensor(values, dtype=torch.int64, device=device) diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index 7b677005d6..64644ba20d 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -255,7 +255,6 @@ def synchronize(self) -> None: class _FakeState: def __init__(self) -> None: - self.use_cuda = True self.pack_stream = _FakeStream() self.copy_stream = _FakeStream() self.next_slot = 0 From 20f45cdd2ba87c60fb70f90c73ac0ea864684a9f Mon Sep 17 00:00:00 2001 From: yihonglie Date: Thu, 4 Jun 2026 03:16:07 -0500 Subject: [PATCH 24/27] Align offload logging with ATOM defaults --- atom/kv_transfer/offload/connector.py | 31 +++++---------------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index 3e25f058f4..45578fa3b3 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -136,26 +136,6 @@ def register_kv_caches(self, kv_caches: dict, transfer_tensors=None) -> None: self._engine.fmt = MemoryFormat.KV_2LTD self._engine.post_init() - # DEBUG: wrap engine.lookup to capture EVERY call (incl. the ones the ZMQ - # lookup_server makes on behalf of the scheduler) — args + result. - _orig_lookup = self._engine.lookup - _rk = rank - - def _logged_lookup(*a, **k): - r = _orig_lookup(*a, **k) - h = k.get("hashes") - logger.debug( - "[ENGINE.LOOKUP] rank=%s lookup_id=%s nhashes=%s first3=%s -> %s", - _rk, - k.get("lookup_id"), - (len(h) if h is not None else None), - (list(h[:3]) if h else None), - r, - ) - return r - - self._engine.lookup = _logged_lookup - # ZMQ lookup server so the scheduler process can query our hit counts. try: from lmcache.v1.lookup_client.factory import LookupClientFactory @@ -236,7 +216,7 @@ def _lookup_unpin(self, req_id) -> None: pass def _profile_enabled(self) -> bool: - return os.environ.get("OFFLOAD_PROFILE", "1").lower() not in ( + return os.environ.get("OFFLOAD_PROFILE", "0").lower() not in ( "0", "false", "no", @@ -392,7 +372,6 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: retrieve_ms, total_ms, ) - logger.info("offload _do_load DONE req=%s", req.req_id) def _do_save_req(self, req: LMCacheReqMeta) -> None: ss = req.save_spec @@ -798,7 +777,7 @@ def _maybe_start_unaligned_handoff( self._handoff_loads.add(sid) seq.offload_loaded_tokens = hbm seq.offload_handoff_boundary_tokens = boundary - logger.info( + logger.debug( "[OFFLOAD-LOAD-HANDOFF] seq=%s hbm_cached=%d boundary=%d " "lmc_cached=%d need_after_boundary=%d min_load=%d chunk=%d", seq.id, @@ -869,7 +848,7 @@ def should_park_partial_prefill_for_load(self, seq) -> bool: self._reqs_need_recv[sid] = seq self._handoff_loads.discard(sid) seq.offload_loaded_tokens = max(hbm, lmc) - logger.info( + logger.debug( "[OFFLOAD-LOAD-HANDOFF-READY] seq=%s hbm_cached=%d " "lmc_cached=%d offload_loaded=%d need=%d", seq.id, @@ -899,7 +878,7 @@ def _mark_load_skip( ) -> None: seq.offload_loaded_tokens = hbm min_load = int(getattr(self, "_min_load_tokens", 8192)) - logger.info( + logger.debug( "[OFFLOAD-LOAD-SKIP] seq=%s hbm_cached=%d lmc_cached=%d " "need=%d min_load=%d chunk=%d reason=%s", seq.id, @@ -977,7 +956,7 @@ def build_connector_meta(self) -> LMCacheOffloadMetadata: # req_id MUST be the raw seq.id (the type the scheduler compares # against in _update_waiting_for_remote_kv); str(seq.id) is only for # LMCache's lookup/pin API. A str here silently never wakes the seq. - logger.info( + logger.debug( "[OFFLOAD-LOAD-EMIT] seq=%s hbm_cached=%d lmc_cached=%d " "offload_loaded=%d need=%d min_load=%d nblocks=%d reason=aligned_large_hit", seq.id, From 4d3bb899acfc5b5129ffdae7f800dad220ea4884 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Thu, 4 Jun 2026 03:41:26 -0500 Subject: [PATCH 25/27] Clarify offload connector module names --- .../offload/{gpu_connector.py => atom_kv_byte_codec.py} | 0 .../{atom_lmcache.py => atom_lmcache_gpu_connector.py} | 2 +- atom/kv_transfer/offload/connector.py | 6 ++++-- tests/test_lmcache_offload_connector.py | 6 ++++-- 4 files changed, 9 insertions(+), 5 deletions(-) rename atom/kv_transfer/offload/{gpu_connector.py => atom_kv_byte_codec.py} (100%) rename atom/kv_transfer/offload/{atom_lmcache.py => atom_lmcache_gpu_connector.py} (99%) diff --git a/atom/kv_transfer/offload/gpu_connector.py b/atom/kv_transfer/offload/atom_kv_byte_codec.py similarity index 100% rename from atom/kv_transfer/offload/gpu_connector.py rename to atom/kv_transfer/offload/atom_kv_byte_codec.py diff --git a/atom/kv_transfer/offload/atom_lmcache.py b/atom/kv_transfer/offload/atom_lmcache_gpu_connector.py similarity index 99% rename from atom/kv_transfer/offload/atom_lmcache.py rename to atom/kv_transfer/offload/atom_lmcache_gpu_connector.py index 444ad29604..4507c71deb 100644 --- a/atom/kv_transfer/offload/atom_lmcache.py +++ b/atom/kv_transfer/offload/atom_lmcache_gpu_connector.py @@ -19,7 +19,7 @@ import torch -from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec +from atom.kv_transfer.offload.atom_kv_byte_codec import ATOMKVByteCodec from atom.kv_transfer.offload.atom_lmcache_staging import ( _StagingSlot, _ThreadTransferState, diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index 45578fa3b3..08e6966b1d 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -38,8 +38,10 @@ ) from atom.kv_transfer.disaggregation.types import KVConnectorOutput, ReqId from atom.kv_transfer.offload import config as offcfg -from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec -from atom.kv_transfer.offload.atom_lmcache import ATOMLMCacheGPUConnector +from atom.kv_transfer.offload.atom_kv_byte_codec import ATOMKVByteCodec +from atom.kv_transfer.offload.atom_lmcache_gpu_connector import ( + ATOMLMCacheGPUConnector, +) from atom.kv_transfer.offload.metadata import ( ATOMRawBytesLMCacheMetadata, LMCacheOffloadMetadata, diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index 64644ba20d..3deb1d220a 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -20,8 +20,10 @@ LMCacheOffloadConnector, LMCacheOffloadConnectorScheduler, ) -from atom.kv_transfer.offload.gpu_connector import ATOMKVByteCodec -from atom.kv_transfer.offload.atom_lmcache import ATOMLMCacheGPUConnector +from atom.kv_transfer.offload.atom_kv_byte_codec import ATOMKVByteCodec +from atom.kv_transfer.offload.atom_lmcache_gpu_connector import ( + ATOMLMCacheGPUConnector, +) from atom.kv_transfer.offload.metadata import ATOMRawBytesLMCacheMetadata from atom.model_engine.scheduler import Scheduler From af36b20dd275f48647e4db5bfad4bb81af5fcf10 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Thu, 4 Jun 2026 04:13:24 -0500 Subject: [PATCH 26/27] Require fused chunk-major offload staging --- .../kv_transfer/offload/atom_kv_byte_codec.py | 138 ++++-------------- .../offload/atom_lmcache_gpu_connector.py | 11 +- tests/test_lmcache_offload_connector.py | 112 +++++++------- 3 files changed, 87 insertions(+), 174 deletions(-) diff --git a/atom/kv_transfer/offload/atom_kv_byte_codec.py b/atom/kv_transfer/offload/atom_kv_byte_codec.py index 3846350e59..fd678f1b48 100644 --- a/atom/kv_transfer/offload/atom_kv_byte_codec.py +++ b/atom/kv_transfer/offload/atom_kv_byte_codec.py @@ -24,6 +24,7 @@ [ all L0.K blocks | all L0.V blocks | all L0.kS blocks | ... ] and batched transfers concatenate those per-chunk buffers for LMCache MemoryObjs. +The production path requires Triton fused chunk-major staging. """ from __future__ import annotations @@ -86,7 +87,7 @@ def __init__(self, kv_caches: dict) -> None: except Exception: logger.warning( "ATOMKVByteCodec: Triton KV staging unavailable; " - "fused chunk-major staging disabled", + "fused chunk-major staging unavailable", exc_info=True, ) @@ -99,14 +100,6 @@ def has_fused_chunk_major_staging(self) -> bool: return self._fused_kv_staging is not None # -- helpers ---------------------------------------------------------- - def _segment_bases(self, nblocks: int) -> list[int]: - bases = [] - acc = 0 - for nb in self._seg_block_bytes: - bases.append(acc) - acc += nb * nblocks - return bases - def _device_ctx(self): if self._device.type == "cuda": return torch.cuda.device(self._device) @@ -158,65 +151,7 @@ def _validate_device_buf(self, device_buf: torch.Tensor, nblocks: int) -> None: f"got {int(device_buf.numel())}" ) - @staticmethod - def _segment_bytes_matrix(seg: torch.Tensor) -> torch.Tensor: - if not seg.is_contiguous(): - raise RuntimeError("ATOMKVByteCodec: segment tensor not contiguous") - return seg.reshape(seg.shape[0], -1).view(torch.uint8) - # -- public API ------------------------------------------------------- - def gpu_to_device_buffer( - self, - device_buf: torch.Tensor, - block_ids: list[int], - stream: torch.cuda.Stream | None = None, - ) -> None: - """Gather ATOM KV blocks into a flat device staging buffer. - - The staging layout is always segment-major: - ``[seg0 blocks | seg1 blocks | ...]``. This is the layout consumed by - the ATOM LMCache connector before it copies the bytes to a - ``MemoryObj``. - """ - block_ids = self._normalize_block_ids(block_ids) - self._validate_device_buf(device_buf, len(block_ids)) - if not block_ids: - return - with self._device_ctx(): - stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() - with stream_ctx: - idx = torch.tensor(block_ids, dtype=torch.long, device=self._device) - bases = self._segment_bases(len(block_ids)) - for seg, base, nb in zip(self._segments, bases, self._seg_block_bytes): - mat = self._segment_bytes_matrix(seg) - dst = device_buf[base : base + len(block_ids) * nb].reshape( - len(block_ids), nb - ) - torch.index_select(mat, 0, idx, out=dst) - - def device_buffer_to_gpu( - self, - device_buf: torch.Tensor, - block_ids: list[int], - stream: torch.cuda.Stream | None = None, - ) -> None: - """Scatter a segment-major device staging buffer into ATOM KV blocks.""" - block_ids = self._normalize_block_ids(block_ids) - self._validate_device_buf(device_buf, len(block_ids)) - if not block_ids: - return - with self._device_ctx(): - stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() - with stream_ctx: - idx = torch.tensor(block_ids, dtype=torch.long, device=self._device) - bases = self._segment_bases(len(block_ids)) - for seg, base, nb in zip(self._segments, bases, self._seg_block_bytes): - mat = self._segment_bytes_matrix(seg) - src = device_buf[base : base + len(block_ids) * nb].reshape( - len(block_ids), nb - ) - mat.index_copy_(0, idx, src) - def gpu_to_chunk_major_device_buffer( self, device_buf: torch.Tensor, @@ -227,39 +162,29 @@ def gpu_to_chunk_major_device_buffer( Layout is MemoryObj-compatible: ``[chunk0: seg0 blocks | seg1 blocks | ...][chunk1: ...]``. - Fused Triton staging is used when available; otherwise this method - provides a reference implementation for tests. + Fused Triton staging is required. """ - groups, flat_block_ids, chunk_block_counts = self._normalize_block_id_groups( + _, flat_block_ids, chunk_block_counts = self._normalize_block_id_groups( block_id_groups, reject_repeated=True, ) self._validate_device_buf(device_buf, len(flat_block_ids)) if not flat_block_ids: return + if self._fused_kv_staging is None: + raise RuntimeError( + "ATOMKVByteCodec requires Triton fused chunk-major staging" + ) with self._device_ctx(): stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() with stream_ctx: - if self._fused_kv_staging is not None: - self._fused_kv_staging.fused_pack_chunk_major( - self._segments, - self._seg_block_bytes, - chunk_block_counts, - flat_block_ids, - device_buf, - ) - return - - offset = 0 - for block_ids in groups: - nblocks = len(block_ids) - chunk_nbytes = nblocks * self.bytes_per_block - self.gpu_to_device_buffer( - device_buf[offset : offset + chunk_nbytes], - block_ids, - stream=stream, - ) - offset += chunk_nbytes + self._fused_kv_staging.fused_pack_chunk_major( + self._segments, + self._seg_block_bytes, + chunk_block_counts, + flat_block_ids, + device_buf, + ) def chunk_major_device_buffer_to_gpu( self, @@ -268,36 +193,27 @@ def chunk_major_device_buffer_to_gpu( stream: torch.cuda.Stream | None = None, ) -> None: """Scatter a chunk-major device staging buffer into ATOM KV blocks.""" - groups, flat_block_ids, chunk_block_counts = self._normalize_block_id_groups( + _, flat_block_ids, chunk_block_counts = self._normalize_block_id_groups( block_id_groups, reject_repeated=True, ) self._validate_device_buf(device_buf, len(flat_block_ids)) if not flat_block_ids: return + if self._fused_kv_staging is None: + raise RuntimeError( + "ATOMKVByteCodec requires Triton fused chunk-major staging" + ) with self._device_ctx(): stream_ctx = torch.cuda.stream(stream) if stream is not None else _nullctx() with stream_ctx: - if self._fused_kv_staging is not None: - self._fused_kv_staging.fused_unpack_chunk_major( - device_buf, - self._segments, - self._seg_block_bytes, - chunk_block_counts, - flat_block_ids, - ) - return - - offset = 0 - for block_ids in groups: - nblocks = len(block_ids) - chunk_nbytes = nblocks * self.bytes_per_block - self.device_buffer_to_gpu( - device_buf[offset : offset + chunk_nbytes], - block_ids, - stream=stream, - ) - offset += chunk_nbytes + self._fused_kv_staging.fused_unpack_chunk_major( + device_buf, + self._segments, + self._seg_block_bytes, + chunk_block_counts, + flat_block_ids, + ) class _nullctx: diff --git a/atom/kv_transfer/offload/atom_lmcache_gpu_connector.py b/atom/kv_transfer/offload/atom_lmcache_gpu_connector.py index 4507c71deb..0a672301b5 100644 --- a/atom/kv_transfer/offload/atom_lmcache_gpu_connector.py +++ b/atom/kv_transfer/offload/atom_lmcache_gpu_connector.py @@ -211,11 +211,8 @@ def _release_slot_if_requested(self, slot: _StagingSlot) -> None: slot.tensor = None slot.free_event_valid = False - def _can_use_fused_chunk_major(self) -> bool: - return self._use_cuda() and self.codec.has_fused_chunk_major_staging - - def _require_fused_chunk_major(self) -> None: - if self._can_use_fused_chunk_major(): + def _assert_fused_chunk_major_available(self) -> None: + if self._use_cuda() and self.codec.has_fused_chunk_major_staging: return raise RuntimeError( "ATOM LMCache connector requires Triton fused chunk-major staging; " @@ -429,7 +426,7 @@ def batched_from_gpu( if not chunks: return - self._require_fused_chunk_major() + self._assert_fused_chunk_major_available() used_slots: list[_StagingSlot] = [] pack_ms = 0.0 copy_ms = 0.0 @@ -500,7 +497,7 @@ def batched_to_gpu( if not chunks: return - self._require_fused_chunk_major() + self._assert_fused_chunk_major_available() used_slots: list[_StagingSlot] = [] copy_ms = 0.0 pack_ms = 0.0 diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index 3deb1d220a..900928e525 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -62,6 +62,50 @@ def _scheduler() -> LMCacheOffloadConnectorScheduler: return sched +def _install_fake_fused_chunk_major(codec: ATOMKVByteCodec) -> None: + def _pack( + segments, + seg_block_bytes, + chunk_block_counts, + flat_block_ids, + device_buf, + ) -> None: + offset = 0 + cursor = 0 + for count in chunk_block_counts: + block_ids = flat_block_ids[cursor : cursor + count] + cursor += count + idx = torch.tensor(block_ids, dtype=torch.long, device=codec.device) + for seg, nbytes in zip(segments, seg_block_bytes): + src = seg.index_select(0, idx).contiguous().view(torch.uint8) + device_buf[offset : offset + count * nbytes].copy_(src.reshape(-1)) + offset += count * nbytes + + def _unpack( + device_buf, + segments, + seg_block_bytes, + chunk_block_counts, + flat_block_ids, + ) -> None: + offset = 0 + cursor = 0 + for count in chunk_block_counts: + block_ids = flat_block_ids[cursor : cursor + count] + cursor += count + idx = torch.tensor(block_ids, dtype=torch.long, device=codec.device) + for seg, nbytes in zip(segments, seg_block_bytes): + src = device_buf[offset : offset + count * nbytes] + src = src.view(seg.dtype).reshape((count,) + tuple(seg.shape[1:])) + seg.index_copy_(0, idx, src) + offset += count * nbytes + + codec._fused_kv_staging = SimpleNamespace( + fused_pack_chunk_major=_pack, + fused_unpack_chunk_major=_unpack, + ) + + def test_raw_bytes_metadata_shapes_are_block_rounded(): import torch @@ -98,57 +142,6 @@ def test_raw_bytes_metadata_rejects_unaligned_chunk_size(): ) -def test_codec_device_buffer_roundtrip_noncontiguous_blocks(): - import torch - - if not hasattr(torch, "arange"): - pytest.skip("real torch is unavailable") - - original = { - "l0": SimpleNamespace( - k_cache=torch.arange(8 * 2 * 3, dtype=torch.uint8).reshape(8, 2, 3), - v_cache=(torch.arange(8 * 4, dtype=torch.uint8).reshape(8, 4) + 51), - k_scale=None, - v_scale=None, - ), - "l1": SimpleNamespace( - k_cache=(torch.arange(8 * 3, dtype=torch.uint8).reshape(8, 3) + 101), - v_cache=(torch.arange(8 * 2, dtype=torch.uint8).reshape(8, 2) + 151), - k_scale=None, - v_scale=None, - ), - } - kv_caches = { - name: SimpleNamespace( - k_cache=layer.k_cache.clone(), v_cache=layer.v_cache.clone() - ) - for name, layer in original.items() - } - for layer in kv_caches.values(): - layer.k_scale = None - layer.v_scale = None - - codec = ATOMKVByteCodec(kv_caches) - block_ids = [1, 3, 4, 7] - device_buf = torch.empty( - len(block_ids) * codec.bytes_per_block, - dtype=torch.uint8, - device=codec.device, - ) - - codec.gpu_to_device_buffer(device_buf, block_ids) - for layer in kv_caches.values(): - layer.k_cache.zero_() - layer.v_cache.zero_() - codec.device_buffer_to_gpu(device_buf, block_ids) - - for name, layer in kv_caches.items(): - src = original[name] - for bid in block_ids: - assert torch.equal(layer.k_cache[bid], src.k_cache[bid]) - assert torch.equal(layer.v_cache[bid], src.v_cache[bid]) - - def test_lmcache_connector_maps_token_ranges_to_block_ids(): import torch @@ -211,9 +204,10 @@ def test_lmcache_connector_fused_chunk_fastpath_uses_chunk_major(monkeypatch): } codec = ATOMKVByteCodec(kv_caches) connector = ATOMLMCacheGPUConnector(codec, block_size=4, chunk_size=8) - monkeypatch.setattr(connector, "_can_use_fused_chunk_major", lambda: True) - orig_pack = codec.gpu_to_chunk_major_device_buffer - orig_unpack = codec.chunk_major_device_buffer_to_gpu + _install_fake_fused_chunk_major(codec) + monkeypatch.setattr( + connector, "_assert_fused_chunk_major_available", lambda: None + ) pack_groups = [] unpack_groups = [] @@ -224,7 +218,9 @@ def test_lmcache_connector_fused_chunk_fastpath_uses_chunk_major(monkeypatch): "gpu_to_chunk_major_device_buffer", lambda device_buf, block_id_groups, stream=None: ( pack_groups.append([list(group) for group in block_id_groups]), - orig_pack(device_buf, block_id_groups, stream=None), + ATOMKVByteCodec.gpu_to_chunk_major_device_buffer( + codec, device_buf, block_id_groups, stream=None + ), )[-1], ) monkeypatch.setattr( @@ -232,7 +228,9 @@ def test_lmcache_connector_fused_chunk_fastpath_uses_chunk_major(monkeypatch): "chunk_major_device_buffer_to_gpu", lambda device_buf, block_id_groups, stream=None: ( unpack_groups.append([list(group) for group in block_id_groups]), - orig_unpack(device_buf, block_id_groups, stream=None), + ATOMKVByteCodec.chunk_major_device_buffer_to_gpu( + codec, device_buf, block_id_groups, stream=None + ), )[-1], ) orig_ensure_slot = connector._ensure_slot @@ -473,6 +471,7 @@ def test_codec_chunk_major_device_buffer_layout(): ) } codec = ATOMKVByteCodec(kv_caches) + _install_fake_fused_chunk_major(codec) block_id_groups = [[0, 1], [2, 3]] device_buf = torch.empty( 4 * codec.bytes_per_block, @@ -530,6 +529,7 @@ def test_codec_chunk_major_handles_tail_and_sparse_blocks(): for name, layer in original.items() } codec = ATOMKVByteCodec(kv_caches) + _install_fake_fused_chunk_major(codec) block_id_groups = [[4, 1, 3], [0]] device_buf = torch.empty( 4 * codec.bytes_per_block, From 069996b90627e18b482a7d69e52800ae471dbe95 Mon Sep 17 00:00:00 2001 From: yihonglie Date: Thu, 4 Jun 2026 12:38:22 -0500 Subject: [PATCH 27/27] Simplify LMCache offload staging buffer --- .../offload/atom_lmcache_gpu_connector.py | 362 +++++++----------- .../offload/atom_lmcache_staging.py | 8 +- atom/kv_transfer/offload/connector.py | 40 +- tests/test_lmcache_offload_connector.py | 66 ++-- 4 files changed, 180 insertions(+), 296 deletions(-) diff --git a/atom/kv_transfer/offload/atom_lmcache_gpu_connector.py b/atom/kv_transfer/offload/atom_lmcache_gpu_connector.py index 0a672301b5..fc5a144ff9 100644 --- a/atom/kv_transfer/offload/atom_lmcache_gpu_connector.py +++ b/atom/kv_transfer/offload/atom_lmcache_gpu_connector.py @@ -14,14 +14,13 @@ from dataclasses import dataclass import threading -import time -from typing import Any +from typing import Any, Callable import torch from atom.kv_transfer.offload.atom_kv_byte_codec import ATOMKVByteCodec from atom.kv_transfer.offload.atom_lmcache_staging import ( - _StagingSlot, + _StagingBuffer, _ThreadTransferState, _env_flag, _env_int, @@ -47,6 +46,18 @@ class _TransferGroup: nbytes: int +@dataclass(frozen=True) +class _PipelineStage: + """One leg of the two-stage staging pipeline. + + ``stream`` is the CUDA stream the work is issued on; ``run(group, + device_buf)`` does the work. + """ + + stream: Any + run: Callable[[_TransferGroup, torch.Tensor], None] + + class ATOMLMCacheGPUConnector: """LMCache GPUConnectorInterface for ATOM's opaque KV-block byte layout.""" @@ -79,7 +90,7 @@ def __init__( ) self.device = torch.device(codec.device) self._tls = threading.local() - requested_group_chunks = _env_int("OFFLOAD_GPU_STAGING_CHUNKS", 2) + requested_buffer_chunks = _env_int("OFFLOAD_GPU_STAGING_CHUNKS", 2) max_staging_bytes = _env_optional_int("OFFLOAD_GPU_STAGING_MAX_BYTES") if max_staging_bytes is not None: if max_staging_bytes < self._gpu_staging_chunk_bytes: @@ -89,78 +100,34 @@ def __init__( f"max_bytes={max_staging_bytes}, " f"chunk_bytes={self._gpu_staging_chunk_bytes}" ) - requested_group_chunks = min( - requested_group_chunks, + requested_buffer_chunks = min( + requested_buffer_chunks, max_staging_bytes // self._gpu_staging_chunk_bytes, ) - self._staging_group_chunks = max(1, int(requested_group_chunks)) - self._gpu_staging_capacity_bytes = ( - self._staging_group_chunks * self._gpu_staging_chunk_bytes + self._staging_buffer_chunks = max(1, int(requested_buffer_chunks)) + self._gpu_staging_buffer_bytes = ( + self._staging_buffer_chunks * self._gpu_staging_chunk_bytes ) - self._staging_slots = _env_int("OFFLOAD_GPU_STAGING_SLOTS", 1) self._release_gpu_staging_after_transfer = _env_flag( "OFFLOAD_RELEASE_GPU_STAGING_AFTER_TRANSFER" ) - @property - def staging_slots(self) -> int: - return self._staging_slots - @property def gpu_staging_chunk_bytes(self) -> int: return self._gpu_staging_chunk_bytes @property - def gpu_staging_group_chunks(self) -> int: - return self._staging_group_chunks + def gpu_staging_buffer_chunks(self) -> int: + return self._staging_buffer_chunks @property - def gpu_staging_capacity_bytes(self) -> int: - return self._gpu_staging_capacity_bytes + def gpu_staging_buffer_bytes(self) -> int: + return self._gpu_staging_buffer_bytes @property def release_gpu_staging_after_transfer(self) -> bool: return self._release_gpu_staging_after_transfer - def _set_last_transfer_stats( - self, - *, - chunks: int = 0, - max_chunk_bytes: int = 0, - groups: int = 0, - max_group_bytes: int = 0, - total_bytes: int = 0, - pack_ms: float = 0.0, - copy_ms: float = 0.0, - sync_ms: float = 0.0, - transfer_ms: float = 0.0, - ) -> None: - effective_gbps = 0.0 - if transfer_ms > 0 and total_bytes > 0: - effective_gbps = total_bytes / (transfer_ms * 1_000_000.0) - self._tls.last_transfer_stats = { - "chunks": int(chunks), - "max_chunk_bytes": int(max_chunk_bytes), - "groups": int(groups), - "max_group_bytes": int(max_group_bytes), - "total_bytes": int(total_bytes), - "gpu_staging_chunk_bytes": self._gpu_staging_chunk_bytes, - "gpu_staging_group_chunks": self._staging_group_chunks, - "gpu_staging_capacity_bytes": self._gpu_staging_capacity_bytes, - "gpu_staging_slots": self._staging_slots, - "pack_ms": float(pack_ms), - "copy_ms": float(copy_ms), - "sync_ms": float(sync_ms), - "transfer_ms": float(transfer_ms), - "effective_gbps": float(effective_gbps), - } - - def reset_transfer_stats(self) -> None: - self._set_last_transfer_stats() - - def last_transfer_stats(self) -> dict[str, int | float]: - return dict(getattr(self._tls, "last_transfer_stats", {})) - def _use_cuda(self) -> bool: return self.device.type == "cuda" @@ -175,41 +142,42 @@ def _thread_state(self) -> _ThreadTransferState: state = _ThreadTransferState( self.device, self._use_cuda(), - self._staging_slots, ) states[key] = state return state - def _ensure_slot(self, slot: _StagingSlot, nbytes: int) -> torch.Tensor: + def _ensure_staging_buffer( + self, + staging_buffer: _StagingBuffer, + nbytes: int, + ) -> torch.Tensor: nbytes = int(nbytes) - if nbytes > self._gpu_staging_capacity_bytes: + if nbytes > self._gpu_staging_buffer_bytes: raise RuntimeError( "ATOM LMCache connector internal error: transfer group exceeds " - "bounded GPU staging capacity: " - f"nbytes={nbytes}, capacity={self._gpu_staging_capacity_bytes}" + "bounded GPU staging buffer: " + f"nbytes={nbytes}, capacity={self._gpu_staging_buffer_bytes}" ) if ( - slot.tensor is None - or int(slot.tensor.numel()) != self._gpu_staging_capacity_bytes + staging_buffer.tensor is None + or int(staging_buffer.tensor.numel()) != self._gpu_staging_buffer_bytes ): - slot.tensor = torch.empty( - (self._gpu_staging_capacity_bytes,), + staging_buffer.tensor = torch.empty( + (self._gpu_staging_buffer_bytes,), dtype=torch.uint8, device=self.device, ) - slot.free_event_valid = False - return slot.tensor[:nbytes] - - def _next_slot(self, state: _ThreadTransferState) -> _StagingSlot: - slot = state.slots[state.next_slot % len(state.slots)] - state.next_slot += 1 - return slot + staging_buffer.free_event_valid = False + return staging_buffer.tensor[:nbytes] - def _release_slot_if_requested(self, slot: _StagingSlot) -> None: + def _release_staging_buffer_if_requested( + self, + staging_buffer: _StagingBuffer, + ) -> None: if not self._release_gpu_staging_after_transfer: return - slot.tensor = None - slot.free_event_valid = False + staging_buffer.tensor = None + staging_buffer.free_event_valid = False def _assert_fused_chunk_major_available(self) -> None: if self._use_cuda() and self.codec.has_fused_chunk_major_staging: @@ -324,9 +292,9 @@ def _iter_transfer_groups( current: list[_TransferChunk] = [] current_bytes = 0 for chunk in chunks: - would_exceed_count = len(current) >= self._staging_group_chunks + would_exceed_count = len(current) >= self._staging_buffer_chunks would_exceed_bytes = ( - current_bytes + chunk.nbytes > self._gpu_staging_capacity_bytes + current_bytes + chunk.nbytes > self._gpu_staging_buffer_bytes ) if current and (would_exceed_count or would_exceed_bytes): groups.append(_TransferGroup(chunks=current, nbytes=current_bytes)) @@ -338,31 +306,6 @@ def _iter_transfer_groups( groups.append(_TransferGroup(chunks=current, nbytes=current_bytes)) return groups - def _record_transfer_stats( - self, - chunks: list[_TransferChunk], - groups: list[_TransferGroup] | None = None, - *, - pack_ms: float = 0.0, - copy_ms: float = 0.0, - sync_ms: float = 0.0, - transfer_ms: float = 0.0, - ) -> None: - if groups is None: - groups = [] - total_bytes = sum(chunk.nbytes for chunk in chunks) - self._set_last_transfer_stats( - chunks=len(chunks), - max_chunk_bytes=max((chunk.nbytes for chunk in chunks), default=0), - groups=len(groups), - max_group_bytes=max((group.nbytes for group in groups), default=0), - total_bytes=total_bytes, - pack_ms=pack_ms, - copy_ms=copy_ms, - sync_ms=sync_ms, - transfer_ms=transfer_ms, - ) - @staticmethod def _group_block_ids(group: _TransferGroup) -> list[list[int]]: return [chunk.block_ids for chunk in group.chunks] @@ -387,16 +330,65 @@ def _memory_objs_to_slice(group: _TransferGroup, dst_buf: torch.Tensor) -> None: ) offset += chunk.nbytes - @staticmethod - def _remember_slot(used_slots: list[_StagingSlot], slot: _StagingSlot) -> None: - if not any(existing is slot for existing in used_slots): - used_slots.append(slot) + def _prepare_transfer( + self, + memory_objs: list[Any] | None, + starts: list[int] | None, + ends: list[int] | None, + **kwargs, + ) -> tuple[_ThreadTransferState, list[_TransferGroup]] | None: + """Validate inputs and build the chunk/group transfer plan.""" + if memory_objs is None or starts is None or ends is None: + raise ValueError("memory_objs, starts, and ends are required") + if not (len(memory_objs) == len(starts) == len(ends)): + raise ValueError("memory_objs, starts, and ends must have equal length") + block_id_groups = self._ranges_to_block_ids(starts, ends, **kwargs) + if not memory_objs: + return None + state = self._thread_state() + chunks = self._iter_transfer_chunks(memory_objs, block_id_groups) + if not chunks: + return None + return state, self._iter_transfer_groups(chunks) - def _release_slots_if_requested(self, used_slots: list[_StagingSlot]) -> None: - if not self._release_gpu_staging_after_transfer: - return - for slot in used_slots: - self._release_slot_if_requested(slot) + def _run_staged_pipeline( + self, + state: _ThreadTransferState, + groups: list[_TransferGroup], + stage_a: _PipelineStage, + stage_b: _PipelineStage, + ) -> None: + """Drive an event-synced two-stage staging pipeline. + + Each group flows ``stage_a`` -> ``stage_b`` on their respective streams, + handed off via the staging buffer's ready event; the free event gates a + later group's reuse of the same buffer. ``stage_b``'s stream produces + the observable result, so it is the one synchronized at the end. + """ + self._assert_fused_chunk_major_available() + staging_buffer = state.staging_buffer + used_buffer = False + try: + for group in groups: + device_buf = self._ensure_staging_buffer(staging_buffer, group.nbytes) + used_buffer = True + if staging_buffer.free_event_valid: + stage_a.stream.wait_event(staging_buffer.free_event) + with state.stream_ctx(stage_a.stream): + stage_a.run(group, device_buf) + staging_buffer.ready_event.record(stage_a.stream) + stage_b.stream.wait_event(staging_buffer.ready_event) + with state.stream_ctx(stage_b.stream): + stage_b.run(group, device_buf) + staging_buffer.free_event.record(stage_b.stream) + staging_buffer.free_event_valid = True + stage_b.stream.synchronize() + except Exception: + staging_buffer.free_event_valid = False + raise + finally: + if used_buffer: + self._release_staging_buffer_if_requested(staging_buffer) def from_gpu(self, memory_obj: Any, start: int, end: int, **kwargs) -> None: self.batched_from_gpu([memory_obj], [start], [end], **kwargs) @@ -412,65 +404,23 @@ def batched_from_gpu( **kwargs, ) -> None: """Pack ATOM KV blocks to LMCache MemoryObjs via bounded staging.""" - if not (len(memory_objs) == len(starts) == len(ends)): - raise ValueError("memory_objs, starts, and ends must have equal length") - block_id_groups = self._ranges_to_block_ids(starts, ends, **kwargs) - if not memory_objs: - self._set_last_transfer_stats() - return - - state = self._thread_state() - chunks = self._iter_transfer_chunks(memory_objs, block_id_groups) - groups = self._iter_transfer_groups(chunks) - self._record_transfer_stats(chunks, groups) - if not chunks: + prepared = self._prepare_transfer(memory_objs, starts, ends, **kwargs) + if prepared is None: return - - self._assert_fused_chunk_major_available() - used_slots: list[_StagingSlot] = [] - pack_ms = 0.0 - copy_ms = 0.0 - sync_ms = 0.0 - t_total0 = time.perf_counter() - try: - for group in groups: - slot = self._next_slot(state) - self._remember_slot(used_slots, slot) - device_buf = self._ensure_slot(slot, group.nbytes) - if slot.free_event_valid: - state.pack_stream.wait_event(slot.free_event) - t0 = time.perf_counter() - with state.stream_ctx(state.pack_stream): - self.codec.gpu_to_chunk_major_device_buffer( - device_buf, - self._group_block_ids(group), - stream=state.pack_stream, - ) - pack_ms += (time.perf_counter() - t0) * 1000 - slot.ready_event.record(state.pack_stream) - state.copy_stream.wait_event(slot.ready_event) - t0 = time.perf_counter() - with state.stream_ctx(state.copy_stream): - self._slice_to_memory_objs(group, device_buf) - copy_ms += (time.perf_counter() - t0) * 1000 - slot.free_event.record(state.copy_stream) - slot.free_event_valid = True - t0 = time.perf_counter() - state.copy_stream.synchronize() - sync_ms += (time.perf_counter() - t0) * 1000 - except Exception: - for slot in used_slots: - slot.free_event_valid = False - raise - finally: - self._release_slots_if_requested(used_slots) - self._record_transfer_stats( - chunks, + state, groups = prepared + self._run_staged_pipeline( + state, groups, - pack_ms=pack_ms, - copy_ms=copy_ms, - sync_ms=sync_ms, - transfer_ms=(time.perf_counter() - t_total0) * 1000, + stage_a=_PipelineStage( + state.pack_stream, + lambda group, buf: self.codec.gpu_to_chunk_major_device_buffer( + buf, self._group_block_ids(group), stream=state.pack_stream + ), + ), + stage_b=_PipelineStage( + state.copy_stream, + lambda group, buf: self._slice_to_memory_objs(group, buf), + ), ) def batched_to_gpu( @@ -481,65 +431,21 @@ def batched_to_gpu( **kwargs, ) -> None: """Load LMCache MemoryObjs back into ATOM KV blocks via bounded staging.""" - if memory_objs is None or starts is None or ends is None: - raise ValueError("memory_objs, starts, and ends are required") - if not (len(memory_objs) == len(starts) == len(ends)): - raise ValueError("memory_objs, starts, and ends must have equal length") - block_id_groups = self._ranges_to_block_ids(starts, ends, **kwargs) - if not memory_objs: - self._set_last_transfer_stats() + prepared = self._prepare_transfer(memory_objs, starts, ends, **kwargs) + if prepared is None: return - - state = self._thread_state() - chunks = self._iter_transfer_chunks(memory_objs, block_id_groups) - groups = self._iter_transfer_groups(chunks) - self._record_transfer_stats(chunks, groups) - if not chunks: - return - - self._assert_fused_chunk_major_available() - used_slots: list[_StagingSlot] = [] - copy_ms = 0.0 - pack_ms = 0.0 - sync_ms = 0.0 - t_total0 = time.perf_counter() - try: - for group in groups: - slot = self._next_slot(state) - self._remember_slot(used_slots, slot) - device_buf = self._ensure_slot(slot, group.nbytes) - if slot.free_event_valid: - state.copy_stream.wait_event(slot.free_event) - t0 = time.perf_counter() - with state.stream_ctx(state.copy_stream): - self._memory_objs_to_slice(group, device_buf) - copy_ms += (time.perf_counter() - t0) * 1000 - slot.ready_event.record(state.copy_stream) - state.pack_stream.wait_event(slot.ready_event) - t0 = time.perf_counter() - with state.stream_ctx(state.pack_stream): - self.codec.chunk_major_device_buffer_to_gpu( - device_buf, - self._group_block_ids(group), - stream=state.pack_stream, - ) - pack_ms += (time.perf_counter() - t0) * 1000 - slot.free_event.record(state.pack_stream) - slot.free_event_valid = True - t0 = time.perf_counter() - state.pack_stream.synchronize() - sync_ms += (time.perf_counter() - t0) * 1000 - except Exception: - for slot in used_slots: - slot.free_event_valid = False - raise - finally: - self._release_slots_if_requested(used_slots) - self._record_transfer_stats( - chunks, + state, groups = prepared + self._run_staged_pipeline( + state, groups, - pack_ms=pack_ms, - copy_ms=copy_ms, - sync_ms=sync_ms, - transfer_ms=(time.perf_counter() - t_total0) * 1000, + stage_a=_PipelineStage( + state.copy_stream, + lambda group, buf: self._memory_objs_to_slice(group, buf), + ), + stage_b=_PipelineStage( + state.pack_stream, + lambda group, buf: self.codec.chunk_major_device_buffer_to_gpu( + buf, self._group_block_ids(group), stream=state.pack_stream + ), + ), ) diff --git a/atom/kv_transfer/offload/atom_lmcache_staging.py b/atom/kv_transfer/offload/atom_lmcache_staging.py index af29ff12fe..b5ef500150 100644 --- a/atom/kv_transfer/offload/atom_lmcache_staging.py +++ b/atom/kv_transfer/offload/atom_lmcache_staging.py @@ -18,7 +18,7 @@ def __exit__(self, *args): return False -class _StagingSlot: +class _StagingBuffer: def __init__(self, use_cuda: bool) -> None: self.tensor: torch.Tensor | None = None self.ready_event = None @@ -64,19 +64,15 @@ def __init__( self, device: torch.device, use_cuda: bool, - staging_slots: int, ) -> None: self.device = device self.pack_stream = None self.copy_stream = None - self.next_slot = 0 if use_cuda: with torch.cuda.device(device): self.pack_stream = torch.cuda.Stream() self.copy_stream = torch.cuda.Stream() - self.slots = [_StagingSlot(use_cuda) for _ in range(staging_slots)] - else: - self.slots = [_StagingSlot(use_cuda) for _ in range(staging_slots)] + self.staging_buffer = _StagingBuffer(use_cuda) def stream_ctx(self, stream): if stream is None: diff --git a/atom/kv_transfer/offload/connector.py b/atom/kv_transfer/offload/connector.py index 08e6966b1d..a29a8040bf 100644 --- a/atom/kv_transfer/offload/connector.py +++ b/atom/kv_transfer/offload/connector.py @@ -150,17 +150,15 @@ def register_kv_caches(self, kv_caches: dict, transfer_tensors=None) -> None: logger.info( "LMCache offload worker rank=%d: bytes_per_block=%d chunk=%d " - "gpu_staging_slots=%d " - "gpu_staging_chunk_bytes=%d gpu_staging_group_chunks=%d " - "gpu_staging_capacity_bytes=%d release_gpu_staging=%s " + "gpu_staging_chunk_bytes=%d gpu_staging_buffer_chunks=%d " + "gpu_staging_buffer_bytes=%d release_gpu_staging=%s " "save=%s load=%s", rank, self._codec.bytes_per_block, self.chunk_size, - gpu_connector.staging_slots, gpu_connector.gpu_staging_chunk_bytes, - gpu_connector.gpu_staging_group_chunks, - gpu_connector.gpu_staging_capacity_bytes, + gpu_connector.gpu_staging_buffer_chunks, + gpu_connector.gpu_staging_buffer_bytes, gpu_connector.release_gpu_staging_after_transfer, self._do_save, self._do_load, @@ -329,9 +327,11 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: max_chunk_bytes=transfer_stats.get("max_chunk_bytes", 0), max_group_bytes=transfer_stats.get("max_group_bytes", 0), gpu_staging_chunk_bytes=transfer_stats.get("gpu_staging_chunk_bytes", 0), - gpu_staging_group_chunks=transfer_stats.get("gpu_staging_group_chunks", 0), - gpu_staging_capacity_bytes=transfer_stats.get( - "gpu_staging_capacity_bytes", 0 + gpu_staging_buffer_chunks=transfer_stats.get( + "gpu_staging_buffer_chunks", 0 + ), + gpu_staging_buffer_bytes=transfer_stats.get( + "gpu_staging_buffer_bytes", 0 ), total_bytes=transfer_stats.get("total_bytes", 0), pack_ms=f"{float(transfer_stats.get('pack_ms', 0.0)):.2f}", @@ -347,8 +347,8 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: "[OFFLOAD-LOAD-PROF] rank=%s req=%s hbm=%d lmc=%d " "retrieved=%d status=%s chunks=%d groups=%d " "max_chunk_bytes=%d max_group_bytes=%d " - "gpu_staging_chunk_bytes=%d gpu_staging_group_chunks=%d " - "gpu_staging_capacity_bytes=%d total_bytes=%d " + "gpu_staging_chunk_bytes=%d gpu_staging_buffer_chunks=%d " + "gpu_staging_buffer_bytes=%d total_bytes=%d " "pack_ms=%.2f copy_ms=%.2f sync_ms=%.2f " "transfer_ms=%.2f effective_gbps=%.2f " "retrieve_ms=%.2f total_ms=%.2f", @@ -363,8 +363,8 @@ def _do_load_req(self, req: LMCacheReqMeta) -> None: int(transfer_stats.get("max_chunk_bytes", 0)), int(transfer_stats.get("max_group_bytes", 0)), int(transfer_stats.get("gpu_staging_chunk_bytes", 0)), - int(transfer_stats.get("gpu_staging_group_chunks", 0)), - int(transfer_stats.get("gpu_staging_capacity_bytes", 0)), + int(transfer_stats.get("gpu_staging_buffer_chunks", 0)), + int(transfer_stats.get("gpu_staging_buffer_bytes", 0)), int(transfer_stats.get("total_bytes", 0)), float(transfer_stats.get("pack_ms", 0.0)), float(transfer_stats.get("copy_ms", 0.0)), @@ -431,9 +431,11 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: max_chunk_bytes=transfer_stats.get("max_chunk_bytes", 0), max_group_bytes=transfer_stats.get("max_group_bytes", 0), gpu_staging_chunk_bytes=transfer_stats.get("gpu_staging_chunk_bytes", 0), - gpu_staging_group_chunks=transfer_stats.get("gpu_staging_group_chunks", 0), - gpu_staging_capacity_bytes=transfer_stats.get( - "gpu_staging_capacity_bytes", 0 + gpu_staging_buffer_chunks=transfer_stats.get( + "gpu_staging_buffer_chunks", 0 + ), + gpu_staging_buffer_bytes=transfer_stats.get( + "gpu_staging_buffer_bytes", 0 ), total_bytes=transfer_stats.get("total_bytes", 0), pack_ms=f"{float(transfer_stats.get('pack_ms', 0.0)):.2f}", @@ -449,7 +451,7 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: "[OFFLOAD-SAVE-PROF] rank=%s req=%s toks=%d skip=%d " "chunks=%d groups=%d max_chunk_bytes=%d max_group_bytes=%d " "gpu_staging_chunk_bytes=%d " - "gpu_staging_group_chunks=%d gpu_staging_capacity_bytes=%d " + "gpu_staging_buffer_chunks=%d gpu_staging_buffer_bytes=%d " "total_bytes=%d pack_ms=%.2f copy_ms=%.2f sync_ms=%.2f " "transfer_ms=%.2f effective_gbps=%.2f " "store_ms=%.2f total_ms=%.2f", @@ -462,8 +464,8 @@ def _do_save_req(self, req: LMCacheReqMeta) -> None: int(transfer_stats.get("max_chunk_bytes", 0)), int(transfer_stats.get("max_group_bytes", 0)), int(transfer_stats.get("gpu_staging_chunk_bytes", 0)), - int(transfer_stats.get("gpu_staging_group_chunks", 0)), - int(transfer_stats.get("gpu_staging_capacity_bytes", 0)), + int(transfer_stats.get("gpu_staging_buffer_chunks", 0)), + int(transfer_stats.get("gpu_staging_buffer_bytes", 0)), int(transfer_stats.get("total_bytes", 0)), float(transfer_stats.get("pack_ms", 0.0)), float(transfer_stats.get("copy_ms", 0.0)), diff --git a/tests/test_lmcache_offload_connector.py b/tests/test_lmcache_offload_connector.py index 900928e525..e5fd0140e6 100644 --- a/tests/test_lmcache_offload_connector.py +++ b/tests/test_lmcache_offload_connector.py @@ -211,7 +211,7 @@ def test_lmcache_connector_fused_chunk_fastpath_uses_chunk_major(monkeypatch): pack_groups = [] unpack_groups = [] - slot_requests = [] + buffer_requests = [] monkeypatch.setattr( codec, @@ -233,14 +233,14 @@ def test_lmcache_connector_fused_chunk_fastpath_uses_chunk_major(monkeypatch): ), )[-1], ) - orig_ensure_slot = connector._ensure_slot + orig_ensure_staging_buffer = connector._ensure_staging_buffer - def _ensure_slot(slot, nbytes): - device_buf = orig_ensure_slot(slot, nbytes) - slot_requests.append((nbytes, int(slot.tensor.numel()))) + def _ensure_staging_buffer(staging_buffer, nbytes): + device_buf = orig_ensure_staging_buffer(staging_buffer, nbytes) + buffer_requests.append((nbytes, int(staging_buffer.tensor.numel()))) return device_buf - monkeypatch.setattr(connector, "_ensure_slot", _ensure_slot) + monkeypatch.setattr(connector, "_ensure_staging_buffer", _ensure_staging_buffer) class _FakeEvent: def record(self, stream) -> None: @@ -257,21 +257,12 @@ class _FakeState: def __init__(self) -> None: self.pack_stream = _FakeStream() self.copy_stream = _FakeStream() - self.next_slot = 0 - self.slots = [ - SimpleNamespace( - tensor=None, - ready_event=_FakeEvent(), - free_event=_FakeEvent(), - free_event_valid=False, - ), - SimpleNamespace( - tensor=None, - ready_event=_FakeEvent(), - free_event=_FakeEvent(), - free_event_valid=False, - ), - ] + self.staging_buffer = SimpleNamespace( + tensor=None, + ready_event=_FakeEvent(), + free_event=_FakeEvent(), + free_event_valid=False, + ) def stream_ctx(self, stream): return nullcontext() @@ -306,20 +297,11 @@ def stream_ctx(self, stream): original["l0"].v_cache[[3]].reshape(-1), ] ) - transfer_stats = connector.last_transfer_stats() - assert transfer_stats["chunks"] == 2 - assert transfer_stats["groups"] == 1 - assert transfer_stats["max_chunk_bytes"] == 2 * codec.bytes_per_block - assert transfer_stats["max_group_bytes"] == 3 * codec.bytes_per_block - assert transfer_stats["total_bytes"] == 3 * codec.bytes_per_block - assert transfer_stats["gpu_staging_chunk_bytes"] == 2 * codec.bytes_per_block - assert transfer_stats["gpu_staging_group_chunks"] == 2 - assert transfer_stats["gpu_staging_capacity_bytes"] == 4 * codec.bytes_per_block - assert transfer_stats["gpu_staging_slots"] == 1 - assert transfer_stats["transfer_ms"] >= 0 assert pack_groups == [[[1, 2], [3]]] - assert all(nbytes <= 4 * codec.bytes_per_block for nbytes, _ in slot_requests) - assert all(capacity == 4 * codec.bytes_per_block for _, capacity in slot_requests) + assert all(nbytes <= 4 * codec.bytes_per_block for nbytes, _ in buffer_requests) + assert all( + capacity == 4 * codec.bytes_per_block for _, capacity in buffer_requests + ) assert torch.equal(memory_objs[0].tensor, expected0) assert torch.equal(memory_objs[1].tensor, expected1) @@ -400,13 +382,12 @@ def test_lmcache_connector_rejects_oversized_memory_obj(): ) -def test_lmcache_connector_respects_staging_slot_env(monkeypatch): +def test_lmcache_connector_respects_staging_buffer_chunks_env(monkeypatch): import torch if not hasattr(torch, "arange"): pytest.skip("real torch is unavailable") - monkeypatch.setenv("OFFLOAD_GPU_STAGING_SLOTS", "2") monkeypatch.setenv("OFFLOAD_GPU_STAGING_CHUNKS", "3") kv_caches = { "l0": SimpleNamespace( @@ -419,13 +400,12 @@ def test_lmcache_connector_respects_staging_slot_env(monkeypatch): codec = ATOMKVByteCodec(kv_caches) connector = ATOMLMCacheGPUConnector(codec, block_size=4, chunk_size=4) - assert connector.staging_slots == 2 - assert connector.gpu_staging_group_chunks == 3 - assert connector.gpu_staging_capacity_bytes == 3 * connector.gpu_staging_chunk_bytes - assert len(connector._thread_state().slots) == 2 + assert connector.gpu_staging_buffer_chunks == 3 + assert connector.gpu_staging_buffer_bytes == 3 * connector.gpu_staging_chunk_bytes + assert connector._thread_state().staging_buffer.tensor is None -def test_lmcache_connector_default_staging_group_chunks_is_two(monkeypatch): +def test_lmcache_connector_default_staging_buffer_chunks_is_two(monkeypatch): import torch if not hasattr(torch, "arange"): @@ -444,8 +424,8 @@ def test_lmcache_connector_default_staging_group_chunks_is_two(monkeypatch): codec = ATOMKVByteCodec(kv_caches) connector = ATOMLMCacheGPUConnector(codec, block_size=4, chunk_size=4) - assert connector.gpu_staging_group_chunks == 2 - assert connector.gpu_staging_capacity_bytes == 2 * connector.gpu_staging_chunk_bytes + assert connector.gpu_staging_buffer_chunks == 2 + assert connector.gpu_staging_buffer_bytes == 2 * connector.gpu_staging_chunk_bytes def test_codec_chunk_major_device_buffer_layout():