30 changes: 30 additions & 0 deletions examples/offline_inference/kv_load_failure_recovery/README.md
@@ -0,0 +1,30 @@
# KV Load Failure Recovery Test

This example builds upon the `disaggregated-prefill-v1` example in `examples/offline_inference`.

It demonstrates vLLM's ability to recover from KV load failures in both synchronous and asynchronous loading modes. The goal is to verify that vLLM correctly identifies invalid KV blocks, reschedules the affected requests, and ensures successful and consistent output.

## Files

- `prefill_example.py` – performs the prefill stage and saves KV data (same as in `disaggregated-prefill-v1`).
- `decode_example.py` – performs the decode stage. Accepts:
- `--simulate-failure`: simulates KV load failure using a custom connector.
- `--async-load`: enables asynchronous KV loading mode.
- `rogue_shared_storage_connector.py` – defines `RogueSharedStorageConnector`, a subclass of `SharedStorageConnector` that simulates missing or corrupted external KV blocks by failing to load the blocks of the first decode request.
- `run.sh` – orchestrates the test: runs the prefill stage, then three decode stages:
1. Normal decode (baseline).
2. Decode with simulated sync KV load failure.
3. Decode with simulated async KV load failure.

Finally, it compares the output of the baseline with the recovered outputs to verify correctness.

## How It Works

- The test dynamically loads `RogueSharedStorageConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector (a configuration sketch follows this list).
- The decode stages that simulate failure are expected to trigger recovery logic in vLLM, resulting in the same output as the baseline decode.
- If recovery fails, the script prints a unified diff of the output mismatch and exits with an error.
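
For reference, this is roughly how the decode stage wires in the rogue connector. This is a minimal sketch mirroring `decode_example.py`; the model name and extra-config keys are taken from that script.

```python
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Point kv_connector_module_path at the local module so vLLM imports
# RogueSharedStorageConnector instead of a built-in connector.
ktc = KVTransferConfig(
    kv_connector="RogueSharedStorageConnector",
    kv_connector_module_path="rogue_shared_storage_connector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "shared_storage_path": "local_storage",
        "async_load": False,  # True exercises the async loading path
    },
)

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
    kv_transfer_config=ktc,
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=10))
```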

## Usage

```bash
./run.sh
```

85 changes: 85 additions & 0 deletions examples/offline_inference/kv_load_failure_recovery/decode_example.py
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
"""Read prompts from prefill_output.txt"""
prompts = []
try:
with open("prefill_output.txt") as f:
for line in f:
prompts.append(line.strip())
print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
return prompts
except FileNotFoundError:
print("Error: prefill_output.txt file not found")
exit(-1)


def main():
prompts = read_prompts()
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

parser = argparse.ArgumentParser()
parser.add_argument(
"--simulate-failure", action="store_true", help="Simulate KV load failure."
)
parser.add_argument(
"--async-load", action="store_true", help="Simulate async KV load"
)
args = parser.parse_args()

if args.simulate_failure:
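        # Use the local RogueSharedStorageConnector (loaded via
        # kv_connector_module_path) to inject KV load failures.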
ktc = KVTransferConfig(
kv_connector="RogueSharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage",
"async_load": args.async_load,
},
kv_connector_module_path="rogue_shared_storage_connector",
)
out_file = (
"async_decode_recovered_output.txt"
if args.async_load
else "sync_decode_recovered_output.txt"
)
else:
ktc = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage",
},
)
out_file = "decode_output.txt"

llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
max_num_batched_tokens=64,
max_num_seqs=16,
kv_transfer_config=ktc,
)

outputs = llm.generate(prompts, sampling_params)

sep_str = "-" * 30
with open(out_file, "w", encoding="utf-8") as f:
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}"
print(out_str)
print(sep_str)
            f.write(out_str + "\n")
            f.write(sep_str + "\n")


if __name__ == "__main__":
main()
58 changes: 58 additions & 0 deletions examples/offline_inference/kv_load_failure_recovery/prefill_example.py
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
context = "Hi " * 1000
context2 = "Hey " * 500
return [
context + "Hello, my name is",
context + "The capital of France is",
context2 + "Your name is",
context2 + "The capital of China is",
]


def main():
prompts = read_prompts()

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
),
) # , max_model_len=2048, max_num_batched_tokens=2048)

# 1ST generation (prefill instance)
outputs = llm.generate(
prompts,
sampling_params,
)

new_prompts = []
print("-" * 30)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
new_prompts.append(prompt + generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 30)

# Write new_prompts to prefill_output.txt
with open("prefill_output.txt", "w") as f:
for prompt in new_prompts:
f.write(prompt + "\n")
print(f"Saved {len(new_prompts)} prompts to prefill_output.txt")


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator:

Shouldn't this be part of the test suite rather than under examples?

Member:

@sdavidbd also wondering about this. Do we have an equivalent of this also covered in the tests?

Contributor Author:

I do have similar logic covered in the unit tests, specifically async/sync loading and shared blocks in the sync case. That said, the unit tests focus on the recovery logic in isolation, while this example serves as an integration test: it exercises the full end-to-end flow and verifies overall correctness.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
SharedStorageConnector,
SharedStorageConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)


@dataclass
class RogueSharedStorageConnectorMetadata(SharedStorageConnectorMetadata):
    req_to_block_ids: dict[str, list[int]] = field(default_factory=dict)

@classmethod
def from_base(cls, base: SharedStorageConnectorMetadata):
return cls(requests=base.requests)


class RogueSharedStorageConnector(SharedStorageConnector):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
"async_load", False
)
        self._invalid_block_ids: Optional[set[int]] = None
self._seen_requests: set = set()
self._req_to_block_ids: dict[str, list[int]] = dict()

def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
assert isinstance(connector_metadata, RogueSharedStorageConnectorMetadata)
index, failed_request = next(
(
(i, x)
for i, x in enumerate(connector_metadata.requests)
if not x.is_store
),
(None, None),
)
if index is not None:
del connector_metadata.requests[index]
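            # slot_mapping maps each token of the request to its slot in the
            # paged KV cache (slot = block_id * block_size + offset), so taking
            # every block_size-th entry and dividing by block_size yields the
            # id of every block the request would have loaded.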
self._invalid_block_ids = set(
(
failed_request.slot_mapping[:: self._block_size] // self._block_size
).tolist()
)
logger.info(
"Simulating failure to load all KV blocks for the "
"first load request. Total blocks: %d",
len(self._invalid_block_ids),
)
super().bind_connector_metadata(connector_metadata)

def clear_connector_metadata(self) -> None:
self._invalid_block_ids = None
super().clear_connector_metadata()

def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None:
if self._async_load and forward_context.attn_metadata is None:
# Bypass sanity check in super().start_load_kv
forward_context.attn_metadata = "None"

super().start_load_kv(forward_context, **kwargs)

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
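        # In async mode, report every request captured in this step's metadata
        # as having finished receiving (loading) its KV; this connector never
        # reports finished sends.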
if self._async_load:
meta = self._get_connector_metadata()
assert isinstance(meta, RogueSharedStorageConnectorMetadata)
if meta.req_to_block_ids:
return None, set(meta.req_to_block_ids)

return None, None

def get_block_ids_with_load_errors(self) -> set[int]:
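        # Block ids recorded in bind_connector_metadata as having "failed" to
        # load; vLLM invalidates them and reschedules the affected request.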
return self._invalid_block_ids

def get_num_new_matched_tokens(
self,
request: Request,
num_computed_tokens: int,
) -> tuple[int, bool]:
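        # Only report externally matched tokens the first time a request is
        # seen; on a retry after the simulated load failure the request falls
        # back to regular computation. The boolean in the returned tuple tells
        # the scheduler whether the external KV will be loaded asynchronously.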
if request.request_id in self._seen_requests:
return 0, False

self._seen_requests.add(request.request_id)

num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens)
return num_tokens, self._async_load and num_tokens > 0

def update_state_after_alloc(
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
):
"""
Update KVConnector state after block allocation.

If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
super().update_state_after_alloc(request, blocks, num_external_tokens)

if num_external_tokens > 0:
self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0]

def build_connector_meta(
self,
scheduler_output: "SchedulerOutput",
) -> KVConnectorMetadata:
if not self._async_load:
base = super().build_connector_meta(scheduler_output)
meta = RogueSharedStorageConnectorMetadata.from_base(base)
else:
meta = RogueSharedStorageConnectorMetadata()
if self._requests_need_load:
for req_id, request in self._requests_need_load.items():
meta.add_request(
token_ids=request.prompt_token_ids,
block_ids=self._req_to_block_ids[req_id],
block_size=self._block_size,
is_store=False,
mm_hashes=[],
)
# Clear state
self._requests_need_load.clear()
meta.req_to_block_ids = self._req_to_block_ids
self._req_to_block_ids = dict()
return meta
33 changes: 33 additions & 0 deletions examples/offline_inference/kv_load_failure_recovery/run.sh
@@ -0,0 +1,33 @@
#!/bin/bash

# Constants
SHARED_STORAGE_DIR="local_storage"
PREFILL_OUTPUT="prefill_output.txt"
DECODE_OUTPUT="decode_output.txt"
SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt"
ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt"

# Cleanup
rm -rf "$SHARED_STORAGE_DIR"
rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"

# Run inference examples
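# VLLM_ENABLE_V1_MULTIPROCESSING=0 keeps the engine core in-process, and
# CUDA_VISIBLE_DEVICES=0 pins every stage to a single GPU.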
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure --async-load

# Compare outputs
if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then
echo "❌ Outputs differ: sync recovery failed."
diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"
exit 1
fi

if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then
echo "❌ Outputs differ: async recovery failed."
diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
exit 1
fi

echo "βœ… Outputs match: recovery successful."