Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,9 @@ class KVEventsConfig:
publisher: str = "zmq" # "null" | "zmq"
endpoint: str = "tcp://127.0.0.1:5557"
topic: str = ""
# ROUTER endpoint subscribers use to request replay of missed batches by
# sequence number. Empty string keeps replay disabled (PUB-only).
replay_endpoint: str = ""
# ZMQ high-water-mark on the PUB socket (0 = unlimited).
hwm: int = 0
# Bounded in-process queue between scheduler and sender thread. When full,
Expand All @@ -972,6 +975,7 @@ def from_env(cls) -> "KVEventsConfig":
publisher=envs.ATOM_KV_EVENTS_PUBLISHER,
endpoint=envs.ATOM_KV_EVENTS_ENDPOINT,
topic=envs.ATOM_KV_EVENTS_TOPIC,
replay_endpoint=envs.ATOM_KV_EVENTS_REPLAY_ENDPOINT,
hwm=envs.ATOM_KV_EVENTS_HWM,
buffer_steps=envs.ATOM_KV_EVENTS_BUFFER_STEPS,
)
Expand Down
87 changes: 78 additions & 9 deletions atom/distributed/kv_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

from __future__ import annotations

import itertools
import logging
import queue
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable
from typing import Any, Final

Expand Down Expand Up @@ -50,6 +52,10 @@ class BlockStored(KVCacheEvent):
# Reserved wire slots; emitted as None until hybrid-cache wiring lands.
kv_cache_spec_kind: str | None = None
kv_cache_spec_sliding_window: int | None = None
# ATOM extension (trailing, so strict vLLM array_like consumers ignore it):
# sequence position of the first token of the first block in this run.
# With block_size, block i covers [token_offset + i*block_size, +block_size).
token_offset: int | None = None


class BlockRemoved(KVCacheEvent):
Expand Down Expand Up @@ -134,10 +140,18 @@ class ZmqEventPublisher(EventPublisher):
batch is dropped — KV events are advisory and a missed eviction is
cheaper than stalling inference.

With `topic=""` (default) each message is a single msgpack frame; with a
non-empty `topic` the publisher sends two-frame multipart messages
(`[topic, payload]`), so consumers must use `recv_multipart()` to read
the payload.
Every message is a three-frame multipart `[topic, seq, payload]`, where
`topic` is the (possibly empty) subscription key, `seq` is a monotonic
8-byte big-endian batch counter, and `payload` is the msgpack-encoded
EventBatch. Consumers must use `recv_multipart()`. The sequence number
lets a subscriber detect batches it missed (e.g. a slow/late SUB that the
PUB socket dropped at the transport level) and, when a `replay_endpoint`
is configured, request them back from the in-memory replay buffer.

Note: `seq` is assigned at send time, so it counts only batches that left
the sender. Batches dropped from the internal queue on overflow (slow
encoder) are advisory losses tracked in `stats['dropped']`; they consume
no sequence number and are not replayable.
"""

def __init__(
Expand All @@ -147,6 +161,7 @@ def __init__(
topic: str = "",
hwm: int = 0,
buffer_steps: int = 10_000,
replay_endpoint: str = "",
data_parallel_rank: int | None = None,
encoder: msgspec.msgpack.Encoder | None = None,
) -> None:
Expand All @@ -170,8 +185,21 @@ def __init__(
self._socket.bind(endpoint)
self._zmq_error_cls = zmq.ZMQError # captured so _run doesn't re-import

# Optional replay: a ROUTER socket + ring buffer of recently-sent
# batches. A subscriber that detects a seq gap can request everything
# from a start sequence number and get the buffered batches back. The
# ROUTER is created here but used only by the sender thread.
self._replay = None
self._replay_buffer: deque[tuple[int, bytes, bytes]] | None = None
if replay_endpoint:
self._replay = ctx.socket(zmq.ROUTER)
self._replay.bind(replay_endpoint)
self._replay_buffer = deque(maxlen=buffer_steps)

self._seq_gen = itertools.count()
self._drops = 0
self._sent = 0
self._replayed = 0
self._encode_errors = 0
self._closing = False
self._lock = threading.Lock()
Expand Down Expand Up @@ -239,30 +267,69 @@ def shutdown(self) -> None:
self._socket.close(linger=linger)
except Exception: # pragma: no cover
pass
if self._replay is not None:
try:
self._replay.close(linger=0)
except Exception: # pragma: no cover
pass

# --- internal ---
def _run(self) -> None:
    """Sender-thread loop: publish queued batches and answer replay requests.

    Each dequeued item is sent as a three-frame multipart message
    `[topic, seq, payload]`, where `seq` is the next value of
    `self._seq_gen` encoded as 8 bytes big-endian. When a replay buffer
    exists, the `(seq, seq_bytes, payload)` triple is appended to it after
    the send. The loop returns when it dequeues the `None` sentinel or when
    the PUB socket raises the captured ZMQ error class.
    """
    # When replay is disabled the queue.get() blocks (timeout=None); when
    # enabled it wakes periodically so replay requests are serviced even
    # while no events are flowing.
    get_timeout = 0.05 if self._replay is not None else None
    while True:
        # Non-blocking check for a pending replay request between sends.
        if self._replay is not None and self._replay.poll(0):
            try:
                self._service_replay()
            except Exception:  # pragma: no cover - replay is non-critical
                logger.exception("KV event replay request failed")
        try:
            item = self._queue.get(timeout=get_timeout)
        except queue.Empty:
            # Timed out with nothing to send; go back and poll for replay.
            continue
        if item is None:
            return
        try:
            # The sequence number is taken at send time, so only batches
            # that actually reach this point consume one.
            seq = next(self._seq_gen)
            seq_bytes = seq.to_bytes(8, "big")
            self._socket.send_multipart([self._topic_bytes, seq_bytes, item])
            if self._replay_buffer is not None:
                self._replay_buffer.append((seq, seq_bytes, item))
            with self._lock:
                self._sent += 1
        except self._zmq_error_cls:  # pragma: no cover - socket closed
            return

def _service_replay(self) -> None:
    """Answer one pending replay request on the ROUTER socket.

    The request is `[client_id, (delim,) start_seq]`, where `start_seq` is
    a big-endian unsigned integer of 1 to 8 bytes. Every buffered batch
    whose sequence number is >= `start_seq` is sent back as
    `[*routing_prefix, seq_bytes, payload]`; the routing prefix is every
    request frame except the last, echoed unchanged.

    Malformed requests (fewer than two frames, or a `start_seq` frame that
    is empty or wider than the 8-byte sequence counter) are logged and
    ignored. `self._replayed` is incremented once per batch sent.
    """
    frames = self._replay.recv_multipart()
    if len(frames) < 2:
        logger.warning("KV event replay: malformed request %r", frames)
        return
    # prefix is [client_id] or [client_id, empty_delim].
    *prefix, raw_start = frames
    # int.from_bytes() never raises on bytes input (b"" decodes to 0), so a
    # bad frame has to be rejected by width; otherwise a bare
    # [client_id, empty_delim] request would replay the whole buffer.
    if not 1 <= len(raw_start) <= 8:
        logger.warning("KV event replay: bad start_seq %r", raw_start)
        return
    start_seq = int.from_bytes(raw_start, "big")
    # NOTE(review): iterates the buffer in place; in the visible code it is
    # only appended to by _run, which is also the caller of this method.
    for seq, seq_bytes, payload in self._replay_buffer or ():
        if seq < start_seq:
            continue
        self._replay.send_multipart([*prefix, seq_bytes, payload])
        with self._lock:
            self._replayed += 1

# Test/diagnostic hooks.
@property
def stats(self) -> dict[str, int]:
    """Return a snapshot of the publisher's counters.

    The four counters are read together while holding `self._lock`. The
    returned dict is a fresh object with the keys `sent`, `dropped`,
    `replayed` and `encode_errors`, in that order.
    """
    with self._lock:
        sent_count, drop_count = self._sent, self._drops
        replay_count, encode_error_count = self._replayed, self._encode_errors
    return dict(
        sent=sent_count,
        dropped=drop_count,
        replayed=replay_count,
        encode_errors=encode_error_count,
    )

Expand All @@ -275,6 +342,7 @@ def make_publisher(
topic: str = "",
hwm: int = 0,
buffer_steps: int = 10_000,
replay_endpoint: str = "",
data_parallel_rank: int | None = None,
) -> EventPublisher:
"""Construct a publisher from plain-config args. Returns `NullEventPublisher`
Expand All @@ -287,6 +355,7 @@ def make_publisher(
topic=topic,
hwm=hwm,
buffer_steps=buffer_steps,
replay_endpoint=replay_endpoint,
data_parallel_rank=data_parallel_rank,
)
raise ValueError(f"unknown KV event publisher: {publisher_kind!r}")
16 changes: 14 additions & 2 deletions atom/model_engine/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,21 @@ def _make_block_stored(
parent: int | None,
block_size: int,
medium: str = MEDIUM_GPU,
token_offset: int | None = None,
) -> BlockStored:
"""Construct a BlockStored event from a coalesced run of new blocks."""
"""Construct a BlockStored event from a coalesced run of new blocks.

`token_offset` is the sequence position of the first token of the run's
first block, so consumers can map block i to
`[token_offset + i*block_size, token_offset + (i+1)*block_size)`.
"""
return BlockStored(
block_hashes=hashes,
parent_block_hash=parent,
token_ids=tokens,
block_size=block_size,
medium=medium,
token_offset=token_offset,
)


Expand Down Expand Up @@ -251,6 +258,7 @@ def hash_blocks(self, seq: Sequence, num_new_tokens: int) -> None:
store_run_tokens,
store_run_parent,
self.block_size,
token_offset=start * self.block_size,
)
)

Expand Down Expand Up @@ -326,11 +334,14 @@ def record_remote_store(
block_hashes: list[int],
token_ids: list[int],
parent_block_hash: int | None = None,
token_offset: int | None = None,
) -> None:
"""Emit a BlockStored(medium=REMOTE) for blocks received from a remote
KV transfer producer (Mooncake/MoriIO decode side). Called by the
KVConnector worker once the transfer completes so external KV-cache
consumers (LMCache, etc.) can track remote-resident blocks."""
consumers (LMCache, etc.) can track remote-resident blocks.

`token_offset` is the sequence position of the first remote block."""
if self._event_log is None or not block_hashes:
return
self._event_log.append(
Expand All @@ -340,5 +351,6 @@ def record_remote_store(
parent_block_hash,
self.block_size,
medium=MEDIUM_REMOTE,
token_offset=token_offset,
)
)
2 changes: 2 additions & 0 deletions atom/model_engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ def __init__(self, config: Config):
topic=kv_events_cfg.topic,
hwm=kv_events_cfg.hwm,
buffer_steps=kv_events_cfg.buffer_steps,
replay_endpoint=kv_events_cfg.replay_endpoint,
data_parallel_rank=dp_rank,
)
logger.info(
Expand Down Expand Up @@ -1334,6 +1335,7 @@ def _update_waiting_for_remote_kv(self, seq: Sequence) -> bool:
block_hashes=remote_hashes,
token_ids=remote_tokens,
parent_block_hash=parent_block_hash,
token_offset=num_cached_blocks * bm.block_size,
)
return True

Expand Down
4 changes: 4 additions & 0 deletions atom/utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@
"ATOM_KV_EVENTS_ENDPOINT", "tcp://127.0.0.1:5557"
),
"ATOM_KV_EVENTS_TOPIC": lambda: os.getenv("ATOM_KV_EVENTS_TOPIC", ""),
# ROUTER endpoint for the replay socket; empty string disables replay.
"ATOM_KV_EVENTS_REPLAY_ENDPOINT": lambda: os.getenv(
"ATOM_KV_EVENTS_REPLAY_ENDPOINT", ""
),
"ATOM_KV_EVENTS_HWM": lambda: int(os.getenv("ATOM_KV_EVENTS_HWM", "0") or "0"),
"ATOM_KV_EVENTS_BUFFER_STEPS": lambda: int(
os.getenv("ATOM_KV_EVENTS_BUFFER_STEPS", "10000") or "10000"
Expand Down
Loading