# [V1] [P/D] Add Support for KV Load Failure Recovery #19330
**Merged:** simon-mo merged 18 commits into `vllm-project:main` from `sdavidbd:feature/kv-load-failure-recovery` on Sep 30, 2025.
## Commits (18)
- `3f21e88` Add KV load failure recovery by rescheduling affected requests
- `c9ca75c` Add KV load failure recovery unit tests
- `092d9e6` Add KV load failure recovery example
- `32fe2e5` Add failure recovery support for async KV load
- `8a29e39` Fix unit tests after rebase
- `e540370` Refactor scheduler changes
- `4f008ee` Fix and refactor rollback handling of invalid sampled tokens
- `99ebaf3` Refactor: always return a set of failed block IDs from connectors
- `cd151ce` Implement get_block_ids_with_load_errors in MultiConnector
- `9c77ca8` Emphasize that KV load failure recovery does not currently support HMA
- `f6b62ce` Fix handling of async KV load failures
- `1b51d6a` Fix handling of progressive async KV load failures
- `3a391f5` Refactor total_affected_tokens calculation
- `4cd273c` Add note on reporting only loadable KV prefix in get_num_new_matched_…
- `51fbe2c` Rename recovered_req_ids to affected_req_ids for clarity
- `c299ff3` Refactor KV connector output emptiness check
- `3d41b47` Fix PR comments
- `46dcace` Fix rollback of invalid output tokens and generator state
## Files changed

### `examples/offline_inference/kv_load_failure_recovery/README.md` (+30, −0)
````markdown
# KV Load Failure Recovery Test

This example builds upon the `disaggregated-prefill-v1` example in `examples/offline_inference`.

It demonstrates vLLM's ability to recover from KV load failures in both synchronous and asynchronous loading modes. The goal is to verify that vLLM correctly identifies invalid KV blocks, reschedules the affected requests, and ensures successful and consistent output.

## Files

- `prefill_example.py`: performs the prefill stage and saves KV data (same as in `disaggregated-prefill-v1`).
- `decode_example.py`: performs the decode stage. Accepts:
  - `--simulate-failure`: simulates KV load failure using a custom connector.
  - `--async-load`: enables asynchronous KV loading mode.
- `rogue_shared_storage_connector.py`: defines `RogueSharedStorageConnector`, a subclass of `SharedStorageConnector` that simulates missing or corrupted external KV blocks by failing to load blocks for the first decode request.
- `run.sh`: orchestrates the test. It runs the prefill stage, then three decode stages:
  1. Normal decode (baseline).
  2. Decode with simulated sync KV load failure.
  3. Decode with simulated async KV load failure.

Finally, it compares the baseline output with the recovered outputs to verify correctness.

## How It Works

- The test dynamically loads `RogueSharedStorageConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector.
- The decode stages that simulate failure are expected to trigger recovery logic in vLLM, resulting in the same output as the baseline decode.
- If recovery fails, the script prints a unified diff of the output mismatch and exits with an error.

## Usage

```bash
./run.sh
```
````
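The recovery flow this example exercises can be sketched in a few lines. This is a simplified, hypothetical illustration: the function and variable names below are not vLLM's actual scheduler internals, which track considerably more state, but they capture the core mapping from failed block IDs to requests that must be rescheduled.

```python
def find_affected_requests(
    failed_block_ids: set[int], req_to_block_ids: dict[str, list[int]]
) -> set[str]:
    """Return IDs of requests that touch at least one failed KV block.

    Hypothetical sketch: the scheduler performs an analogous lookup when a
    connector reports load errors, then reschedules the affected requests so
    their KV is recomputed instead of trusting the invalid blocks.
    """
    return {
        req_id
        for req_id, block_ids in req_to_block_ids.items()
        if failed_block_ids & set(block_ids)
    }


# Example: block 5 failed to load; only request "b" uses it.
affected = find_affected_requests({5}, {"a": [1, 2], "b": [5, 6], "c": [7]})
print(affected)  # {'b'}
```

Requests whose blocks all loaded successfully are untouched, which is why the baseline and recovered outputs are expected to match exactly.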
### `examples/offline_inference/kv_load_failure_recovery/decode_example.py` (+85, −0)
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
    """Read prompts from prefill_output.txt"""
    prompts = []
    try:
        with open("prefill_output.txt") as f:
            for line in f:
                prompts.append(line.strip())
        print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
        return prompts
    except FileNotFoundError:
        print("Error: prefill_output.txt file not found")
        exit(-1)


def main():
    prompts = read_prompts()
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--simulate-failure", action="store_true", help="Simulate KV load failure."
    )
    parser.add_argument(
        "--async-load", action="store_true", help="Simulate async KV load"
    )
    args = parser.parse_args()

    if args.simulate_failure:
        ktc = KVTransferConfig(
            kv_connector="RogueSharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
                "async_load": args.async_load,
            },
            kv_connector_module_path="rogue_shared_storage_connector",
        )
        out_file = (
            "async_decode_recovered_output.txt"
            if args.async_load
            else "sync_decode_recovered_output.txt"
        )
    else:
        ktc = KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
            },
        )
        out_file = "decode_output.txt"

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        max_num_batched_tokens=64,
        max_num_seqs=16,
        kv_transfer_config=ktc,
    )

    outputs = llm.generate(prompts, sampling_params)

    sep_str = "-" * 30
    with open(out_file, "w", encoding="utf-8") as f:
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}"
            print(out_str)
            print(sep_str)
            f.write(out_str)
            f.write(sep_str)


if __name__ == "__main__":
    main()
```
### `examples/offline_inference/kv_load_failure_recovery/prefill_example.py` (+58, −0)
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
    context = "Hi " * 1000
    context2 = "Hey " * 500
    return [
        context + "Hello, my name is",
        context + "The capital of France is",
        context2 + "Your name is",
        context2 + "The capital of China is",
    ]


def main():
    prompts = read_prompts()

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        kv_transfer_config=KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        ),
    )  # , max_model_len=2048, max_num_batched_tokens=2048)

    # 1st generation (prefill instance)
    outputs = llm.generate(
        prompts,
        sampling_params,
    )

    new_prompts = []
    print("-" * 30)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        new_prompts.append(prompt + generated_text)
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 30)

    # Write new_prompts to prefill_output.txt
    with open("prefill_output.txt", "w") as f:
        for prompt in new_prompts:
            f.write(prompt + "\n")
    print(f"Saved {len(new_prompts)} prompts to prefill_output.txt")


if __name__ == "__main__":
    main()
```
### `examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py` (+145, −0)
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
    SharedStorageConnector,
    SharedStorageConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)


@dataclass
class RogueSharedStorageConnectorMetadata(SharedStorageConnectorMetadata):
    req_to_block_ids: dict[str, set[int]] = field(default_factory=dict)

    @classmethod
    def from_base(cls, base: SharedStorageConnectorMetadata):
        return cls(requests=base.requests)


class RogueSharedStorageConnector(SharedStorageConnector):
    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
            "async_load", False
        )
        self._invalid_block_ids: set = None
        self._seen_requests: set = set()
        self._req_to_block_ids: dict[str, list[int]] = dict()

    def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
        assert isinstance(connector_metadata, RogueSharedStorageConnectorMetadata)
        index, failed_request = next(
            (
                (i, x)
                for i, x in enumerate(connector_metadata.requests)
                if not x.is_store
            ),
            (None, None),
        )
        if index is not None:
            del connector_metadata.requests[index]
            self._invalid_block_ids = set(
                (
                    failed_request.slot_mapping[:: self._block_size] // self._block_size
                ).tolist()
            )
            logger.info(
                "Simulating failure to load all KV blocks for the "
                "first load request. Total blocks: %d",
                len(self._invalid_block_ids),
            )
        super().bind_connector_metadata(connector_metadata)

    def clear_connector_metadata(self) -> None:
        self._invalid_block_ids = None
        super().clear_connector_metadata()

    def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None:
        if self._async_load and forward_context.attn_metadata is None:
            # Bypass sanity check in super().start_load_kv
            forward_context.attn_metadata = "None"

        super().start_load_kv(forward_context, **kwargs)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        if self._async_load:
            meta = self._get_connector_metadata()
            assert isinstance(meta, RogueSharedStorageConnectorMetadata)
            if meta.req_to_block_ids:
                return None, set(meta.req_to_block_ids)

        return None, None

    def get_block_ids_with_load_errors(self) -> set[int]:
        return self._invalid_block_ids

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        if request.request_id in self._seen_requests:
            return 0, False

        self._seen_requests.add(request.request_id)

        num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens)
        return num_tokens, self._async_load and num_tokens > 0

    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.

        If blocks were allocated, add to _requests_need_load,
        such that we load the KVs in the next forward pass.
        """
        super().update_state_after_alloc(request, blocks, num_external_tokens)

        if num_external_tokens > 0:
            self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0]

    def build_connector_meta(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> KVConnectorMetadata:
        if not self._async_load:
            base = super().build_connector_meta(scheduler_output)
            meta = RogueSharedStorageConnectorMetadata.from_base(base)
        else:
            meta = RogueSharedStorageConnectorMetadata()
            if self._requests_need_load:
                for req_id, request in self._requests_need_load.items():
                    meta.add_request(
                        token_ids=request.prompt_token_ids,
                        block_ids=self._req_to_block_ids[req_id],
                        block_size=self._block_size,
                        is_store=False,
                        mm_hashes=[],
                    )
                # Clear state
                self._requests_need_load.clear()
        meta.req_to_block_ids = self._req_to_block_ids
        self._req_to_block_ids = dict()
        return meta
```
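The block-ID derivation in `bind_connector_metadata`, `slot_mapping[::block_size] // block_size`, is easy to check with plain integers. In the connector, `slot_mapping` is a torch tensor mapping each token to its KV slot (`block_id * block_size + offset`); the sketch below mirrors the same arithmetic with a Python list and hypothetical values.

```python
block_size = 4

# Hypothetical slot mapping for a request whose tokens occupy blocks 2, 5, and 9:
# token i lives in slot (block_id * block_size + offset_within_block).
slot_mapping = [8, 9, 10, 11, 20, 21, 22, 23, 36, 37]

# Taking every block_size-th slot and integer-dividing by block_size yields
# one slot per block, hence the set of block IDs the request occupies.
invalid_block_ids = {slot // block_size for slot in slot_mapping[::block_size]}
print(invalid_block_ids)  # {2, 5, 9}
```

Note this relies on each block's slots being contiguous and the mapping starting at a block boundary, which holds for the first load request the connector fails on purpose.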
### `examples/offline_inference/kv_load_failure_recovery/run.sh` (+33, −0)
```bash
#!/bin/bash

# Constants
SHARED_STORAGE_DIR="local_storage"
PREFILL_OUTPUT="prefill_output.txt"
DECODE_OUTPUT="decode_output.txt"
SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt"
ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt"

# Cleanup
rm -rf "$SHARED_STORAGE_DIR"
rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"

# Run inference examples
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure --async-load

# Compare outputs
if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then
    echo "❌ Outputs differ: sync recovery failed."
    diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"
    exit 1
fi

if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then
    echo "❌ Outputs differ: async recovery failed."
    diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
    exit 1
fi

echo "✅ Outputs match: recovery successful."
```
## Review comments

**Reviewer:** shouldn't this be part of the test suite rather than under examples?
**Reviewer:** @sdavidbd also wondering about this. Do we have an equivalent of this also covered in the tests?
**Reply:** I do have similar logic covered in the unit tests, specifically async/sync loading and shared blocks in the sync case. That said, the unit tests focus on the recovery logic in isolation, while this example serves as an integration test: it exercises the full end-to-end flow and verifies overall correctness.