From 56567d04a921679c08c6d3c2442fbfc31c605feb Mon Sep 17 00:00:00 2001 From: bongwoobak Date: Tue, 23 Jun 2026 03:56:43 +0900 Subject: [PATCH] feat(kv-events): add block token_offset, sequence numbers, and replay Build on #869 so external KV-event consumers can align and trust the stream: - BlockStored.token_offset: sequence position of the run's first block, so a consumer maps block i to [token_offset + i*block_size, token_offset + (i+1)*block_size). - ZmqEventPublisher tags each sent batch with a monotonic sequence number and sends [topic, seq, payload]; subscribers can detect batches lost in transit as seq gaps (batches dropped from the internal queue on overflow consume no sequence number and are only counted in stats['dropped']). - Optional replay ROUTER socket + ring buffer: a subscriber requests missed batches by start sequence number (ATOM_KV_EVENTS_REPLAY_ENDPOINT). Wire change: PUB messages go from a single frame (or two-frame [topic, payload] when a topic was set) to three-frame [topic, seq, payload]; all consumers must now use recv_multipart(). Validated on MI300 (Qwen3-0.6B). --- atom/config.py | 4 + atom/distributed/kv_events.py | 87 ++++++++++-- atom/model_engine/block_manager.py | 16 ++- atom/model_engine/scheduler.py | 2 + atom/utils/envs.py | 4 + tests/test_kv_events.py | 208 ++++++++++++++++++++++++++++- 6 files changed, 307 insertions(+), 14 deletions(-) diff --git a/atom/config.py b/atom/config.py index 6e764b8e1a..641594e989 100644 --- a/atom/config.py +++ b/atom/config.py @@ -956,6 +956,9 @@ class KVEventsConfig: publisher: str = "zmq" # "null" | "zmq" endpoint: str = "tcp://127.0.0.1:5557" topic: str = "" + # ROUTER endpoint subscribers use to request replay of missed batches by + # sequence number. Empty string keeps replay disabled (PUB-only). + replay_endpoint: str = "" # ZMQ high-water-mark on the PUB socket (0 = unlimited). hwm: int = 0 # Bounded in-process queue between scheduler and sender thread. 
When full, @@ -972,6 +975,7 @@ def from_env(cls) -> "KVEventsConfig": publisher=envs.ATOM_KV_EVENTS_PUBLISHER, endpoint=envs.ATOM_KV_EVENTS_ENDPOINT, topic=envs.ATOM_KV_EVENTS_TOPIC, + replay_endpoint=envs.ATOM_KV_EVENTS_REPLAY_ENDPOINT, hwm=envs.ATOM_KV_EVENTS_HWM, buffer_steps=envs.ATOM_KV_EVENTS_BUFFER_STEPS, ) diff --git a/atom/distributed/kv_events.py b/atom/distributed/kv_events.py index 5dd4830bd3..c69d898612 100644 --- a/atom/distributed/kv_events.py +++ b/atom/distributed/kv_events.py @@ -6,11 +6,13 @@ from __future__ import annotations +import itertools import logging import queue import threading import time from abc import ABC, abstractmethod +from collections import deque from collections.abc import Iterable from typing import Any, Final @@ -50,6 +52,10 @@ class BlockStored(KVCacheEvent): # Reserved wire slots; emitted as None until hybrid-cache wiring lands. kv_cache_spec_kind: str | None = None kv_cache_spec_sliding_window: int | None = None + # ATOM extension (trailing, so strict vLLM array_like consumers ignore it): + # sequence position of the first token of the first block in this run. + # With block_size, block i covers [token_offset + i*block_size, +block_size). + token_offset: int | None = None class BlockRemoved(KVCacheEvent): @@ -134,10 +140,18 @@ class ZmqEventPublisher(EventPublisher): batch is dropped — KV events are advisory and a missed eviction is cheaper than stalling inference. - With `topic=""` (default) each message is a single msgpack frame; with a - non-empty `topic` the publisher sends two-frame multipart messages - (`[topic, payload]`), so consumers must use `recv_multipart()` to read - the payload. + Every message is a three-frame multipart `[topic, seq, payload]`, where + `topic` is the (possibly empty) subscription key, `seq` is a monotonic + 8-byte big-endian batch counter, and `payload` is the msgpack-encoded + EventBatch. Consumers must use `recv_multipart()`. 
The sequence number + lets a subscriber detect batches it missed (e.g. a slow/late SUB that the + PUB socket dropped at the transport level) and, when a `replay_endpoint` + is configured, request them back from the in-memory replay buffer. + + Note: `seq` is assigned at send time, so it counts only batches that left + the sender. Batches dropped from the internal queue on overflow (slow + encoder) are advisory losses tracked in `stats['dropped']`; they consume + no sequence number and are not replayable. """ def __init__( @@ -147,6 +161,7 @@ def __init__( topic: str = "", hwm: int = 0, buffer_steps: int = 10_000, + replay_endpoint: str = "", data_parallel_rank: int | None = None, encoder: msgspec.msgpack.Encoder | None = None, ) -> None: @@ -170,8 +185,21 @@ def __init__( self._socket.bind(endpoint) self._zmq_error_cls = zmq.ZMQError # captured so _run doesn't re-import + # Optional replay: a ROUTER socket + ring buffer of recently-sent + # batches. A subscriber that detects a seq gap can request everything + # from a start sequence number and get the buffered batches back. The + # ROUTER is created here but used only by the sender thread. + self._replay = None + self._replay_buffer: deque[tuple[int, bytes, bytes]] | None = None + if replay_endpoint: + self._replay = ctx.socket(zmq.ROUTER) + self._replay.bind(replay_endpoint) + self._replay_buffer = deque(maxlen=buffer_steps) + + self._seq_gen = itertools.count() self._drops = 0 self._sent = 0 + self._replayed = 0 self._encode_errors = 0 self._closing = False self._lock = threading.Lock() @@ -239,23 +267,61 @@ def shutdown(self) -> None: self._socket.close(linger=linger) except Exception: # pragma: no cover pass + if self._replay is not None: + try: + self._replay.close(linger=0) + except Exception: # pragma: no cover + pass # --- internal --- def _run(self) -> None: + # Poll the replay socket between sends. 
When replay is disabled the + # queue.get() blocks (timeout=None); when enabled it wakes periodically + # so replay requests are serviced even while no events are flowing. + get_timeout = 0.05 if self._replay is not None else None while True: - item = self._queue.get() + if self._replay is not None and self._replay.poll(0): + try: + self._service_replay() + except Exception: # pragma: no cover - replay is non-critical + logger.exception("KV event replay request failed") + try: + item = self._queue.get(timeout=get_timeout) + except queue.Empty: + continue if item is None: return try: - if self._topic_bytes: - self._socket.send_multipart([self._topic_bytes, item]) - else: - self._socket.send(item) + seq = next(self._seq_gen) + seq_bytes = seq.to_bytes(8, "big") + self._socket.send_multipart([self._topic_bytes, seq_bytes, item]) + if self._replay_buffer is not None: + self._replay_buffer.append((seq, seq_bytes, item)) with self._lock: self._sent += 1 except self._zmq_error_cls: # pragma: no cover - socket closed return + def _service_replay(self) -> None: + """Answer a pending replay request: resend every buffered batch with + seq >= the requested start sequence. Request frame is + `[client_id, (delim,) start_seq]`; we echo the routing prefix back.""" + frames = self._replay.recv_multipart() + if len(frames) < 2: + logger.warning("KV event replay: malformed request %r", frames) + return + try: + start_seq = int.from_bytes(frames[-1], "big") + except Exception: + logger.warning("KV event replay: bad start_seq %r", frames[-1]) + return + prefix = frames[:-1] # [client_id] or [client_id, empty_delim] + for seq, seq_bytes, payload in list(self._replay_buffer or ()): + if seq >= start_seq: + self._replay.send_multipart([*prefix, seq_bytes, payload]) + with self._lock: + self._replayed += 1 + # Test/diagnostic hooks. 
@property def stats(self) -> dict[str, int]: @@ -263,6 +329,7 @@ def stats(self) -> dict[str, int]: return { "sent": self._sent, "dropped": self._drops, + "replayed": self._replayed, "encode_errors": self._encode_errors, } @@ -275,6 +342,7 @@ def make_publisher( topic: str = "", hwm: int = 0, buffer_steps: int = 10_000, + replay_endpoint: str = "", data_parallel_rank: int | None = None, ) -> EventPublisher: """Construct a publisher from plain-config args. Returns `NullEventPublisher` @@ -287,6 +355,7 @@ def make_publisher( topic=topic, hwm=hwm, buffer_steps=buffer_steps, + replay_endpoint=replay_endpoint, data_parallel_rank=data_parallel_rank, ) raise ValueError(f"unknown KV event publisher: {publisher_kind!r}") diff --git a/atom/model_engine/block_manager.py b/atom/model_engine/block_manager.py index 4ea5c14948..891dab6639 100644 --- a/atom/model_engine/block_manager.py +++ b/atom/model_engine/block_manager.py @@ -40,14 +40,21 @@ def _make_block_stored( parent: int | None, block_size: int, medium: str = MEDIUM_GPU, + token_offset: int | None = None, ) -> BlockStored: - """Construct a BlockStored event from a coalesced run of new blocks.""" + """Construct a BlockStored event from a coalesced run of new blocks. + + `token_offset` is the sequence position of the first token of the run's + first block, so consumers can map block i to + `[token_offset + i*block_size, token_offset + (i+1)*block_size)`. 
+ """ return BlockStored( block_hashes=hashes, parent_block_hash=parent, token_ids=tokens, block_size=block_size, medium=medium, + token_offset=token_offset, ) @@ -251,6 +258,7 @@ def hash_blocks(self, seq: Sequence, num_new_tokens: int) -> None: store_run_tokens, store_run_parent, self.block_size, + token_offset=start * self.block_size, ) ) @@ -326,11 +334,14 @@ def record_remote_store( block_hashes: list[int], token_ids: list[int], parent_block_hash: int | None = None, + token_offset: int | None = None, ) -> None: """Emit a BlockStored(medium=REMOTE) for blocks received from a remote KV transfer producer (Mooncake/MoriIO decode side). Called by the KVConnector worker once the transfer completes so external KV-cache - consumers (LMCache, etc.) can track remote-resident blocks.""" + consumers (LMCache, etc.) can track remote-resident blocks. + + `token_offset` is the sequence position of the first remote block.""" if self._event_log is None or not block_hashes: return self._event_log.append( @@ -340,5 +351,6 @@ def record_remote_store( parent_block_hash, self.block_size, medium=MEDIUM_REMOTE, + token_offset=token_offset, ) ) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index dab098f1a9..0dfdddfab2 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -483,6 +483,7 @@ def __init__(self, config: Config): topic=kv_events_cfg.topic, hwm=kv_events_cfg.hwm, buffer_steps=kv_events_cfg.buffer_steps, + replay_endpoint=kv_events_cfg.replay_endpoint, data_parallel_rank=dp_rank, ) logger.info( @@ -1334,6 +1335,7 @@ def _update_waiting_for_remote_kv(self, seq: Sequence) -> bool: block_hashes=remote_hashes, token_ids=remote_tokens, parent_block_hash=parent_block_hash, + token_offset=num_cached_blocks * bm.block_size, ) return True diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 87b00cba74..ff4c131843 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -189,6 +189,10 @@ "ATOM_KV_EVENTS_ENDPOINT", 
"tcp://127.0.0.1:5557" ), "ATOM_KV_EVENTS_TOPIC": lambda: os.getenv("ATOM_KV_EVENTS_TOPIC", ""), + # ROUTER endpoint for the replay socket; empty string disables replay. + "ATOM_KV_EVENTS_REPLAY_ENDPOINT": lambda: os.getenv( + "ATOM_KV_EVENTS_REPLAY_ENDPOINT", "" + ), "ATOM_KV_EVENTS_HWM": lambda: int(os.getenv("ATOM_KV_EVENTS_HWM", "0") or "0"), "ATOM_KV_EVENTS_BUFFER_STEPS": lambda: int( os.getenv("ATOM_KV_EVENTS_BUFFER_STEPS", "10000") or "10000" diff --git a/tests/test_kv_events.py b/tests/test_kv_events.py index c752fac771..be277935dc 100644 --- a/tests/test_kv_events.py +++ b/tests/test_kv_events.py @@ -75,6 +75,32 @@ def test_block_stored_roundtrip(self): assert dec.medium == MEDIUM_GPU assert dec.block_size == 4 + def test_block_stored_token_offset_roundtrip(self): + # token_offset records the sequence position the first block of the run + # covers, so consumers can map each block to [offset + i*block_size, ...). + evt = BlockStored( + block_hashes=[111, 222], + parent_block_hash=None, + token_ids=[1, 2, 3, 4, 5, 6, 7, 8], + block_size=4, + token_offset=16, + ) + enc = msgspec.msgpack.Encoder().encode(evt) + dec = msgspec.msgpack.Decoder(BlockStored).decode(enc) + assert dec.token_offset == 16 + + def test_block_stored_token_offset_defaults_none(self): + evt = BlockStored( + block_hashes=[1], + parent_block_hash=None, + token_ids=[1, 2, 3, 4], + block_size=4, + ) + dec = msgspec.msgpack.Decoder(BlockStored).decode( + msgspec.msgpack.Encoder().encode(evt) + ) + assert dec.token_offset is None + def test_block_removed_roundtrip(self): evt = BlockRemoved(block_hashes=[111], medium=MEDIUM_GPU) enc = msgspec.msgpack.Encoder().encode(evt) @@ -150,6 +176,29 @@ def test_block_stored_on_first_allocate(self, seq_factory): assert stored[0].block_size == 4 assert stored[0].medium == MEDIUM_GPU + def test_block_stored_first_run_offset_is_zero(self, seq_factory): + bm = _bm_with_events() + seq = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + _admit(bm, seq) + stored = [e for e 
in bm.take_events() if isinstance(e, BlockStored)] + assert len(stored) == 1 + assert stored[0].token_offset == 0 + + def test_block_stored_offset_after_cached_prefix(self, seq_factory): + # s2 reuses s1's first two blocks (tokens 1-8) and adds a third + # block (tokens 9-12). The new run starts at block index 2, so its + # token_offset must be 2 * block_size == 8. + bm = _bm_with_events() + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + _admit(bm, s1) + bm.take_events() + + s2 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + _admit(bm, s2) + stored = [e for e in bm.take_events() if isinstance(e, BlockStored)] + assert len(stored) == 1 + assert stored[0].token_offset == 8 + def test_drain_is_destructive(self, seq_factory): bm = _bm_with_events() seq = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) @@ -238,6 +287,18 @@ def test_record_remote_store(self, seq_factory): assert events[0].medium == MEDIUM_REMOTE assert events[0].block_hashes == [42, 43] + def test_record_remote_store_carries_token_offset(self, seq_factory): + bm = _bm_with_events() + bm.record_remote_store( + block_hashes=[42, 43], + token_ids=[1, 2, 3, 4, 5, 6, 7, 8], + parent_block_hash=None, + token_offset=16, + ) + events = bm.take_events() + assert len(events) == 1 + assert events[0].token_offset == 16 + def test_record_remote_store_no_op_when_disabled(self, block_manager): # block_manager fixture has events disabled block_manager.record_remote_store(block_hashes=[1], token_ids=[0]) @@ -265,6 +326,40 @@ def test_make_publisher_unknown_kind_raises(self): with pytest.raises(ValueError): make_publisher(enabled=True, publisher_kind="kafka", endpoint="") + +class TestReplayEndpointWiring: + def test_make_publisher_forwards_replay_endpoint(self): + pytest.importorskip("zmq") + pub = make_publisher( + enabled=True, + publisher_kind="zmq", + endpoint="inproc://mp-replay-pub", + replay_endpoint="inproc://mp-replay-router", + ) + try: + assert pub._replay is not None + assert pub._replay_buffer is not None + 
finally: + pub.shutdown() + + def test_make_publisher_no_replay_by_default(self): + pytest.importorskip("zmq") + pub = make_publisher( + enabled=True, publisher_kind="zmq", endpoint="inproc://mp-noreplay" + ) + try: + assert pub._replay is None + finally: + pub.shutdown() + + def test_env_replay_endpoint_default_and_override(self, monkeypatch): + import atom.utils.envs as envs + + monkeypatch.delenv("ATOM_KV_EVENTS_REPLAY_ENDPOINT", raising=False) + assert envs.ATOM_KV_EVENTS_REPLAY_ENDPOINT == "" + monkeypatch.setenv("ATOM_KV_EVENTS_REPLAY_ENDPOINT", "tcp://127.0.0.1:5558") + assert envs.ATOM_KV_EVENTS_REPLAY_ENDPOINT == "tcp://127.0.0.1:5558" + def test_zmq_publisher_roundtrip(self): # Skip cleanly when pyzmq isn't installed (zmq is an optional dep of # the publisher, not of the engine). @@ -280,13 +375,18 @@ def test_zmq_publisher_roundtrip(self): sub.setsockopt(zmq.SUBSCRIBE, b"") sub.connect(endpoint) decoder = msgspec.msgpack.Decoder(EventBatch) - payload: bytes | None = None + frames: list[bytes] | None = None for _ in range(10): pub.publish([BlockRemoved(block_hashes=[7])]) if sub.poll(timeout=200): - payload = sub.recv() + frames = sub.recv_multipart() break - assert payload is not None, "SUB did not receive any batch" + assert frames is not None, "SUB did not receive any batch" + # Wire layout is [topic, seq, payload]; topic is empty by default. 
+ assert len(frames) == 3 + topic, seq_bytes, payload = frames + assert topic == b"" + assert int.from_bytes(seq_bytes, "big") == 0 # first batch batch = decoder.decode(payload) assert len(batch.events) == 1 assert isinstance(batch.events[0], BlockRemoved) @@ -294,6 +394,108 @@ def test_zmq_publisher_roundtrip(self): sub.close(linger=0) pub.shutdown() + def test_zmq_publisher_seq_is_monotonic(self): + zmq = pytest.importorskip("zmq") + endpoint = "inproc://test-kv-events-seq" + pub = ZmqEventPublisher(endpoint=endpoint, buffer_steps=64) + ctx = zmq.Context.instance() + sub = ctx.socket(zmq.SUB) + try: + sub.setsockopt(zmq.SUBSCRIBE, b"") + sub.connect(endpoint) + seqs: list[int] = [] + for i in range(5): + pub.publish([BlockRemoved(block_hashes=[i])]) + deadline_polls = 50 + while len(seqs) < 5 and deadline_polls > 0: + if sub.poll(timeout=200): + _, seq_bytes, _ = sub.recv_multipart() + seqs.append(int.from_bytes(seq_bytes, "big")) + deadline_polls -= 1 + assert seqs == [0, 1, 2, 3, 4] + finally: + sub.close(linger=0) + pub.shutdown() + + def test_replay_recovers_missed_batches(self): + zmq = pytest.importorskip("zmq") + pub_ep = "inproc://test-kv-replay-pub" + replay_ep = "inproc://test-kv-replay-router" + pub = ZmqEventPublisher( + endpoint=pub_ep, replay_endpoint=replay_ep, buffer_steps=64 + ) + ctx = zmq.Context.instance() + sub = ctx.socket(zmq.SUB) + try: + sub.setsockopt(zmq.SUBSCRIBE, b"") + sub.connect(pub_ep) + for i in range(3): + pub.publish([BlockRemoved(block_hashes=[i])]) + received = 0 + polls = 50 + while received < 3 and polls > 0: + if sub.poll(timeout=200): + sub.recv_multipart() + received += 1 + polls -= 1 + assert received == 3 + + dealer = ctx.socket(zmq.DEALER) + dealer.connect(replay_ep) + dealer.send(b"\x00" * 8) # start_seq = 0 + replayed_seqs: list[int] = [] + replay_polls = 50 + while len(replayed_seqs) < 3 and replay_polls > 0: + if dealer.poll(timeout=200): + seq_bytes, payload = dealer.recv_multipart() + 
replayed_seqs.append(int.from_bytes(seq_bytes, "big")) + msgspec.msgpack.Decoder(EventBatch).decode(payload) + replay_polls -= 1 + dealer.close(linger=0) + assert replayed_seqs == [0, 1, 2] + finally: + sub.close(linger=0) + pub.shutdown() + + def test_replay_only_returns_from_start_seq(self): + zmq = pytest.importorskip("zmq") + pub_ep = "inproc://test-kv-replay-pub2" + replay_ep = "inproc://test-kv-replay-router2" + pub = ZmqEventPublisher( + endpoint=pub_ep, replay_endpoint=replay_ep, buffer_steps=64 + ) + ctx = zmq.Context.instance() + sub = ctx.socket(zmq.SUB) + try: + sub.setsockopt(zmq.SUBSCRIBE, b"") + sub.connect(pub_ep) + for i in range(4): + pub.publish([BlockRemoved(block_hashes=[i])]) + received = 0 + polls = 50 + while received < 4 and polls > 0: + if sub.poll(timeout=200): + sub.recv_multipart() + received += 1 + polls -= 1 + assert received == 4 + + dealer = ctx.socket(zmq.DEALER) + dealer.connect(replay_ep) + dealer.send((2).to_bytes(8, "big")) # start_seq = 2 + replayed_seqs: list[int] = [] + replay_polls = 50 + while len(replayed_seqs) < 2 and replay_polls > 0: + if dealer.poll(timeout=200): + seq_bytes, _ = dealer.recv_multipart() + replayed_seqs.append(int.from_bytes(seq_bytes, "big")) + replay_polls -= 1 + dealer.close(linger=0) + assert replayed_seqs == [2, 3] + finally: + sub.close(linger=0) + pub.shutdown() + def test_publish_drops_oldest_on_overflow(self): # buffer_steps=1 + stopped sender => every publish past the first must # drop the oldest queued item and tick stats["dropped"].