Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
48e63ba
[kv_transfer] ATOM standalone LMCache CPU/NVMe KV-offload connector (…
yhl-amd May 30, 2026
291b43f
Optimize LMCache offload reload path
yhl-amd May 31, 2026
7ec6015
Optimize LMCache offload chunked prefill path
yhl-amd Jun 1, 2026
77dbd02
WIP lmcache partial reload and HOL wake
yhl-amd Jun 1, 2026
5f76dd7
Revert "WIP lmcache partial reload and HOL wake"
yhl-amd Jun 1, 2026
895ecdb
Fix LMCache offload reload handoff
yhl-amd Jun 2, 2026
6974804
Reduce offload resume scheduler diff noise
yhl-amd Jun 2, 2026
af8ae03
Format scheduler with Black
yhl-amd Jun 2, 2026
e48edfa
Fix offload formatting and lint
yhl-amd Jun 2, 2026
bdbe0c1
Clarify parked offload prefill naming
yhl-amd Jun 2, 2026
da02bdd
Support max completion tokens in OpenAI API
yhl-amd Jun 2, 2026
d23a9f4
Add LMCache-compatible offload connector
yhl-amd Jun 3, 2026
b0e300e
Add fused chunk-major LMCache staging
yhl-amd Jun 3, 2026
7b39986
Add bounded LMCache staging with chunk2 default
yhl-amd Jun 4, 2026
5084d24
Remove obsolete offload staging fallback code
yhl-amd Jun 4, 2026
28dc7df
Remove unused offload staging sizing hooks
yhl-amd Jun 4, 2026
0901ea9
Split LMCache offload metadata and staging helpers
yhl-amd Jun 4, 2026
0cabc88
Merge LMCache metadata wrapper into offload metadata
yhl-amd Jun 4, 2026
86073f4
Replace native KV staging with Triton kernel
yhl-amd Jun 4, 2026
0044651
Refactor scheduler remote KV admission
yhl-amd Jun 4, 2026
889248b
Remove offload host fallback staging
yhl-amd Jun 4, 2026
53eb02a
Rename ATOM LMCache offload modules
yhl-amd Jun 4, 2026
47f9ca6
Remove obsolete offload staging switches
yhl-amd Jun 4, 2026
20f45cd
Align offload logging with ATOM defaults
yhl-amd Jun 4, 2026
4d3bb89
Clarify offload connector module names
yhl-amd Jun 4, 2026
af36b20
Require fused chunk-major offload staging
yhl-amd Jun 4, 2026
069996b
Simplify LMCache offload staging buffer
yhl-amd Jun 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 79 additions & 4 deletions atom/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,43 @@ def _coerce_n(requested_n: Optional[int], temperature: Optional[float]) -> int:
return n


def _validate_context_length(
num_prompt_tokens: int,
max_tokens: int,
max_model_len: Optional[int],
) -> None:
if max_model_len is None:
return

requested_output_tokens = max(0, int(max_tokens or 0))
total_tokens = int(num_prompt_tokens) + requested_output_tokens
if total_tokens <= int(max_model_len):
return

raise ValueError(
f"This model's maximum context length is {max_model_len} tokens. "
f"However, you requested {requested_output_tokens} output tokens and "
f"your prompt contains at least {num_prompt_tokens} input tokens, for "
f"a total of at least {total_tokens} tokens. Please reduce the length "
f"of the input prompt or the number of requested output tokens."
)


def _get_engine_max_model_len() -> Optional[int]:
config = getattr(engine, "config", None)
if config is None:
config = getattr(getattr(engine, "io_processor", None), "config", None)
return getattr(config, "max_model_len", None)


def _validate_sequence_context_length(seq) -> None:
_validate_context_length(
seq.num_prompt_tokens,
seq.max_tokens,
_get_engine_max_model_len(),
)


def _has_multimodal_content(messages: List[Any]) -> bool:
for message in messages:
content = getattr(message, "content", None)
Expand Down Expand Up @@ -369,6 +406,11 @@ def do_preprocess():
)

seq = await loop.run_in_executor(None, do_preprocess)
try:
_validate_sequence_context_length(seq)
except Exception:
engine.io_processor.requests.pop(seq.id, None)
raise
engine.core_mgr.add_request([seq])

while True:
Expand Down Expand Up @@ -454,6 +496,11 @@ def do_preprocess():
)

seq = await loop.run_in_executor(None, do_preprocess)
try:
_validate_sequence_context_length(seq)
except Exception:
engine.io_processor.requests.pop(seq.id, None)
raise
engine.core_mgr.add_request([seq])

while True:
Expand Down Expand Up @@ -553,6 +600,12 @@ def do_preprocess():
)

seqs = await loop.run_in_executor(None, do_preprocess)
try:
_validate_sequence_context_length(seqs[0])
except Exception:
for seq in seqs:
engine.io_processor.requests.pop(seq.id, None)
raise
engine.core_mgr.add_request(seqs)
num_tokens_input = seqs[0].num_prompt_tokens

Expand Down Expand Up @@ -649,7 +702,18 @@ def do_preprocess():
_seq_id_to_request_id[seq.id] = request_id
return seq

seq = await executor_loop.run_in_executor(None, do_preprocess)
seq = None
try:
seq = await executor_loop.run_in_executor(None, do_preprocess)
_validate_sequence_context_length(seq)
except Exception:
_stream_queues.pop(request_id, None)
_stream_loops.pop(request_id, None)
_request_start_times.pop(request_id, None)
if seq is not None:
_seq_id_to_request_id.pop(seq.id, None)
engine.io_processor.requests.pop(seq.id, None)
raise
seq_id = seq.id

logger.info(f"API: Created request_id={request_id}, seq_id={seq_id}")
Expand Down Expand Up @@ -723,7 +787,18 @@ def do_preprocess():
_seq_id_to_request_id[seq.id] = request_id
return seqs

seqs = await executor_loop.run_in_executor(None, do_preprocess)
seqs = []
try:
seqs = await executor_loop.run_in_executor(None, do_preprocess)
_validate_sequence_context_length(seqs[0])
except Exception:
_stream_queues.pop(request_id, None)
_stream_loops.pop(request_id, None)
_request_start_times.pop(request_id, None)
for seq in seqs:
_seq_id_to_request_id.pop(seq.id, None)
engine.io_processor.requests.pop(seq.id, None)
raise
seq_ids = [seq.id for seq in seqs]
logger.info(
f"API: Created fan-out request_id={request_id}, n={n}, seq_ids={seq_ids}"
Expand Down Expand Up @@ -802,7 +877,7 @@ async def chat_completions(request: ChatCompletionRequest):
effective_n = _coerce_n(request.n, request.temperature)
sampling_params = _build_sampling_params(
temperature=request.temperature,
max_tokens=request.max_tokens,
max_tokens=request.get_max_tokens(),
stop_strings=request.stop,
ignore_eos=request.ignore_eos,
top_k=request.top_k,
Expand Down Expand Up @@ -931,7 +1006,7 @@ async def completions(request: CompletionRequest):
effective_n = _coerce_n(request.n, request.temperature)
sampling_params = _build_sampling_params(
temperature=request.temperature,
max_tokens=request.max_tokens,
max_tokens=request.get_max_tokens(),
stop_strings=request.stop,
ignore_eos=request.ignore_eos,
top_k=request.top_k,
Expand Down
18 changes: 18 additions & 0 deletions atom/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class ChatCompletionRequest(BaseModel):
top_k: Optional[int] = DEFAULT_TOP_K
top_p: Optional[float] = DEFAULT_TOP_P
max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
max_completion_tokens: Optional[int] = None
stop: Optional[List[str]] = None
ignore_eos: Optional[bool] = False
stream: Optional[bool] = False
Expand All @@ -90,6 +91,14 @@ class ChatCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = 0.0
n: Optional[int] = 1

def get_max_tokens(self) -> int:
"""Return the effective generation cap for OpenAI chat requests."""
if self.max_completion_tokens is not None:
return self.max_completion_tokens
if self.max_tokens is not None:
return self.max_tokens
return DEFAULT_MAX_TOKENS

def get_messages(self) -> List[ChatMessage]:
"""Get messages from either 'messages' or 'prompt' field."""
if self.messages is not None:
Expand All @@ -111,13 +120,22 @@ class CompletionRequest(BaseModel):
top_k: Optional[int] = DEFAULT_TOP_K
top_p: Optional[float] = DEFAULT_TOP_P
max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
max_completion_tokens: Optional[int] = None
stop: Optional[List[str]] = None
ignore_eos: Optional[bool] = False
stream: Optional[bool] = False
# Optional KV-transfer metadata for P/D disaggregation.
kv_transfer_params: Optional[Dict[str, Any]] = None
n: Optional[int] = 1

def get_max_tokens(self) -> int:
"""Return the effective generation cap for completion requests."""
if self.max_completion_tokens is not None:
return self.max_completion_tokens
if self.max_tokens is not None:
return self.max_tokens
return DEFAULT_MAX_TOKENS


# ============================================================================
# Response Models
Expand Down
68 changes: 64 additions & 4 deletions atom/kv_transfer/disaggregation/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

import logging

from atom.kv_transfer.disaggregation.types import KVConnectorOutput
from atom.kv_transfer.disaggregation.types import KVConnectorOutput, ReqId
from atom.kv_transfer.offload.trace import offload_trace

logger = logging.getLogger("atom")

Expand Down Expand Up @@ -48,8 +49,10 @@ def __init__(self, world_size: int = 8) -> None:
if world_size <= 0:
raise ValueError(f"world_size must be positive, got {world_size}")
self._world_size = world_size
self._seen_sending: dict[str, set[int]] = {}
self._seen_recving: dict[str, set[int]] = {}
self._seen_sending: dict[ReqId, set[int]] = {}
self._seen_recving: dict[ReqId, set[int]] = {}
self._seen_recv_failed: dict[ReqId, set[int]] = {}
self._seen_saving: dict[ReqId, set[int]] = {}

@property
def world_size(self) -> int:
Expand All @@ -76,34 +79,91 @@ def aggregate(self, worker_outputs: list[KVConnectorOutput]) -> KVConnectorOutpu
if wo.finished_recving:
for rid in wo.finished_recving:
self._seen_recving.setdefault(rid, set()).add(worker_idx)
offload_trace(
"aggregator_worker_recv_done",
worker=worker_idx,
req=rid,
seen=len(self._seen_recving[rid]),
world=self._world_size,
)
if wo.failed_recving:
for rid in wo.failed_recving:
self._seen_recv_failed.setdefault(rid, set()).add(worker_idx)
offload_trace(
"aggregator_worker_recv_failed",
worker=worker_idx,
req=rid,
seen=len(self._seen_recv_failed[rid]),
world=self._world_size,
)
if wo.finished_saving:
for rid in wo.finished_saving:
self._seen_saving.setdefault(rid, set()).add(worker_idx)

done_sending = {
rid
for rid, workers in self._seen_sending.items()
if len(workers) >= self._world_size
}
failed_recving = set()
recv_ids = set(self._seen_recving) | set(self._seen_recv_failed)
for rid in recv_ids:
done_workers = self._seen_recving.get(rid, set())
failed_workers = self._seen_recv_failed.get(rid, set())
if (
failed_workers
and len(done_workers | failed_workers) >= self._world_size
):
failed_recving.add(rid)
done_recving = {
rid
for rid, workers in self._seen_recving.items()
if len(workers) >= self._world_size and rid not in failed_recving
}
done_saving = {
rid
for rid, workers in self._seen_saving.items()
if len(workers) >= self._world_size
}

for rid in done_sending:
del self._seen_sending[rid]
for rid in done_recving:
del self._seen_recving[rid]
self._seen_recv_failed.pop(rid, None)
for rid in failed_recving:
self._seen_recving.pop(rid, None)
self._seen_recv_failed.pop(rid, None)
for rid in done_saving:
del self._seen_saving[rid]

if done_recving or failed_recving or done_saving:
offload_trace(
"aggregator_done",
recv=sorted(done_recving),
failed=sorted(failed_recving),
saving=sorted(done_saving),
)

return KVConnectorOutput(
finished_sending=done_sending,
finished_recving=done_recving,
failed_recving=failed_recving,
finished_saving=done_saving,
)

def reset(self) -> None:
"""Clear all internal tracking state."""
self._seen_sending.clear()
self._seen_recving.clear()
self._seen_recv_failed.clear()
self._seen_saving.clear()

@property
def pending_count(self) -> tuple[int, int]:
"""Return ``(num_pending_sending, num_pending_recving)``."""
return len(self._seen_sending), len(self._seen_recving)
return (
len(self._seen_sending),
len(set(self._seen_recving) | set(self._seen_recv_failed))
+ len(self._seen_saving),
)
9 changes: 6 additions & 3 deletions atom/kv_transfer/disaggregation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from abc import ABC, abstractmethod
from typing import Any

from atom.kv_transfer.disaggregation.types import ConnectorMetadata
from atom.kv_transfer.disaggregation.types import ConnectorMetadata, KVConnectorOutput


class KVConnectorBase(ABC):
Expand All @@ -48,8 +48,11 @@ def start_load_kv(self, metadata: ConnectorMetadata) -> None:
...

@abstractmethod
def get_finished(self) -> tuple[set, set]:
"""Return ``(done_sending, done_recving)`` request ID sets.
def get_finished(self) -> tuple[set, set] | KVConnectorOutput:
"""Return transfer completion status.

Older connectors may return ``(done_sending, done_recving)``. Connectors
that need richer semantics can return :class:`KVConnectorOutput`.

Called by the worker each engine step to report transfer status.
"""
Expand Down
9 changes: 9 additions & 0 deletions atom/kv_transfer/disaggregation/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,12 @@ def create_connector(
scheduler_module="atom.kv_transfer.disaggregation.mooncake.mooncake_connector",
scheduler_class="MooncakeConnectorScheduler",
)


# ATOM standalone CPU/NVMe KV offload backend (registers "lmcache_offload").
# Import is lightweight (offload/__init__ only records module paths as strings;
# the connector module is imported lazily by create_connector when selected).
try:
import atom.kv_transfer.offload # noqa: F401,E402
except Exception as _e: # pragma: no cover - offload optional (needs lmcache)
logger.debug("lmcache_offload backend not registered: %s", _e)
21 changes: 16 additions & 5 deletions atom/kv_transfer/disaggregation/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# ---------------------------------------------------------------------------

EngineId = str
ReqId = str
ReqId = str | int
TransferId = int

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -59,22 +59,33 @@ class KVConnectorOutput:
Attributes:
finished_sending: Request IDs whose KV send completed on this worker.
finished_recving: Request IDs whose KV receive completed on this worker.
failed_recving: Request IDs whose KV receive failed on this worker.
finished_saving: Request IDs whose local fire-and-forget save completed.
expected_finished_count: How many finished notifications should be
expected per request (used by the aggregator).
"""

finished_sending: set[str] = field(default_factory=set)
finished_recving: set[str] = field(default_factory=set)
finished_sending: set[ReqId] = field(default_factory=set)
finished_recving: set[ReqId] = field(default_factory=set)
failed_recving: set[ReqId] = field(default_factory=set)
finished_saving: set[ReqId] = field(default_factory=set)
expected_finished_count: int = 0

def is_empty(self) -> bool:
"""Return True if no transfers finished on this worker."""
return not self.finished_sending and not self.finished_recving
return (
not self.finished_sending
and not self.finished_recving
and not self.failed_recving
and not self.finished_saving
)

def __repr__(self) -> str:
return (
f"KVConnectorOutput(sending={self.finished_sending}, "
f"recving={self.finished_recving})"
f"recving={self.finished_recving}, "
f"failed_recving={self.failed_recving}, "
f"finished_saving={self.finished_saving})"
)


Expand Down
Loading