diff --git a/examples/offline_inference/kv_load_failure_recovery/README.md b/examples/offline_inference/kv_load_failure_recovery/README.md
new file mode 100644
index 000000000000..230a16812b25
--- /dev/null
+++ b/examples/offline_inference/kv_load_failure_recovery/README.md
@@ -0,0 +1,30 @@
+# KV Load Failure Recovery Test
+
+This example builds upon the `disaggregated-prefill-v1` example in `examples/offline_inference`.
+
+It demonstrates vLLM's ability to recover from KV load failures in both synchronous and asynchronous loading modes. The goal is to verify that vLLM correctly identifies invalid KV blocks, reschedules the affected requests, and produces output identical to the failure-free baseline.
+
+## Files
+
+- `prefill_example.py` – performs the prefill stage and saves KV data (same as in `disaggregated-prefill-v1`).
+- `decode_example.py` – performs the decode stage. Accepts:
+  - `--simulate-failure`: simulates KV load failure using a custom connector.
+  - `--async-load`: enables asynchronous KV loading mode.
+- `rogue_shared_storage_connector.py` – defines `RogueSharedStorageConnector`, a subclass of `SharedStorageConnector`, that simulates missing or corrupted external KV blocks by failing to load blocks for the first decode request.
+- `run.sh` – orchestrates the test: runs the prefill stage, then three decode stages:
+  1. Normal decode (baseline).
+  2. Decode with simulated sync KV load failure.
+  3. Decode with simulated async KV load failure.
+
+  Finally, it compares the output of the baseline with the recovered outputs to verify correctness.
+
+## How It Works
+
+- The test dynamically loads `RogueSharedStorageConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector.
+- The decode stages that simulate failure are expected to trigger recovery logic in vLLM, resulting in the same output as the baseline decode.
+- If recovery fails, the script prints a unified diff of the output mismatch and exits with an error.
+
+## Usage
+
+```bash
+./run.sh
+```
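+
+## Customizing the Failure Simulation
+
+The failure path is selected purely through `KVTransferConfig`; no vLLM sources need to be patched. A minimal sketch, mirroring what `decode_example.py` builds when `--simulate-failure` is passed:
+
+```python
+from vllm.config import KVTransferConfig
+
+ktc = KVTransferConfig(
+    kv_connector="RogueSharedStorageConnector",
+    kv_role="kv_both",
+    kv_connector_extra_config={
+        "shared_storage_path": "local_storage",
+        # True exercises the async-load path, False the sync path.
+        "async_load": False,
+    },
+    kv_connector_module_path="rogue_shared_storage_connector",
+)
+```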
diff --git a/examples/offline_inference/kv_load_failure_recovery/decode_example.py b/examples/offline_inference/kv_load_failure_recovery/decode_example.py
new file mode 100644
index 000000000000..69523f56eace
--- /dev/null
+++ b/examples/offline_inference/kv_load_failure_recovery/decode_example.py
@@ -0,0 +1,85 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import argparse
+
+from vllm import LLM, SamplingParams
+from vllm.config import KVTransferConfig
+
+
+def read_prompts():
+    """Read prompts from prefill_output.txt"""
+    prompts = []
+    try:
+        with open("prefill_output.txt") as f:
+            for line in f:
+                prompts.append(line.strip())
+        print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
+        return prompts
+    except FileNotFoundError:
+        print("Error: prefill_output.txt file not found")
+        exit(-1)
+
+
+def main():
+    prompts = read_prompts()
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--simulate-failure", action="store_true", help="Simulate KV load failure."
+    )
+    parser.add_argument(
+        "--async-load", action="store_true", help="Simulate async KV load."
+    )
+    args = parser.parse_args()
+
+    if args.simulate_failure:
+        ktc = KVTransferConfig(
+            kv_connector="RogueSharedStorageConnector",
+            kv_role="kv_both",
+            kv_connector_extra_config={
+                "shared_storage_path": "local_storage",
+                "async_load": args.async_load,
+            },
+            kv_connector_module_path="rogue_shared_storage_connector",
+        )
+        out_file = (
+            "async_decode_recovered_output.txt"
+            if args.async_load
+            else "sync_decode_recovered_output.txt"
+        )
+    else:
+        ktc = KVTransferConfig(
+            kv_connector="SharedStorageConnector",
+            kv_role="kv_both",
+            kv_connector_extra_config={
+                "shared_storage_path": "local_storage",
+            },
+        )
+        out_file = "decode_output.txt"
+
+    llm = LLM(
+        model="meta-llama/Llama-3.2-1B-Instruct",
+        enforce_eager=True,
+        gpu_memory_utilization=0.8,
+        max_num_batched_tokens=64,
+        max_num_seqs=16,
+        kv_transfer_config=ktc,
+    )
+
+    outputs = llm.generate(prompts, sampling_params)
+
+    sep_str = "-" * 30
+    with open(out_file, "w", encoding="utf-8") as f:
+        for output in outputs:
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}"
+            print(out_str)
+            print(sep_str)
+            f.write(out_str + "\n")
+            f.write(sep_str + "\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/offline_inference/kv_load_failure_recovery/prefill_example.py b/examples/offline_inference/kv_load_failure_recovery/prefill_example.py
new file mode 100644
index 000000000000..047b81c82df5
--- /dev/null
+++ b/examples/offline_inference/kv_load_failure_recovery/prefill_example.py
@@ -0,0 +1,58 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm import LLM, SamplingParams
+from vllm.config import KVTransferConfig
+
+
+def read_prompts():
+    context = "Hi " * 1000
+    context2 = "Hey " * 500
+    return [
+        context + "Hello, my name is",
+        context + "The capital of France is",
+        context2 + "Your name is",
+        context2 + "The capital of China is",
+    ]
+
+
+def main():
+    prompts = read_prompts()
+
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
+
+    llm = LLM(
+        model="meta-llama/Llama-3.2-1B-Instruct",
+        enforce_eager=True,
+        gpu_memory_utilization=0.8,
+        kv_transfer_config=KVTransferConfig(
+            kv_connector="SharedStorageConnector",
+            kv_role="kv_both",
+            kv_connector_extra_config={"shared_storage_path": "local_storage"},
+        ),
+    )
+
+    # 1st generation (prefill instance)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+    )
+
+    new_prompts = []
+    print("-" * 30)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        new_prompts.append(prompt + generated_text)
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 30)
+
+    # Write new_prompts to prefill_output.txt
+    with open("prefill_output.txt", "w") as f:
+        for prompt in new_prompts:
+            f.write(prompt + "\n")
+    print(f"Saved {len(new_prompts)} prompts to prefill_output.txt")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py b/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py
new file mode 100644
index 000000000000..0abe7d161261
--- /dev/null
+++ b/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py
@@ -0,0 +1,145 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# ruff: noqa: E501
+import logging
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Optional
+
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    KVConnectorMetadata,
+    KVConnectorRole,
+)
+from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
+    SharedStorageConnector,
+    SharedStorageConnectorMetadata,
+)
+from vllm.forward_context import ForwardContext
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+from vllm.v1.request import Request
+
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+
+logger = logging.getLogger()
+logging.basicConfig(level=logging.INFO)
+
+
+@dataclass
+class RogueSharedStorageConnectorMetadata(SharedStorageConnectorMetadata):
+    req_to_block_ids: dict[str, list[int]] = field(default_factory=dict)
+
+    @classmethod
+    def from_base(cls, base: SharedStorageConnectorMetadata):
+        return cls(requests=base.requests)
+
+
+class RogueSharedStorageConnector(SharedStorageConnector):
+    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
+        super().__init__(vllm_config=vllm_config, role=role)
+        self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
+            "async_load", False
+        )
+        self._invalid_block_ids: Optional[set[int]] = None
+        self._seen_requests: set[str] = set()
+        self._req_to_block_ids: dict[str, list[int]] = dict()
+
+    def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
+        assert isinstance(connector_metadata, RogueSharedStorageConnectorMetadata)
+        index, failed_request = next(
+            (
+                (i, x)
+                for i, x in enumerate(connector_metadata.requests)
+                if not x.is_store
+            ),
+            (None, None),
+        )
+        if index is not None:
+            del connector_metadata.requests[index]
+            # slot_mapping[::block_size] // block_size yields the ID of each
+            # block touched by the request; mark all of them as failed.
+            self._invalid_block_ids = set(
+                (
+                    failed_request.slot_mapping[:: self._block_size] // self._block_size
+                ).tolist()
+            )
+            logger.info(
+                "Simulating failure to load all KV blocks for the "
+                "first load request. Total blocks: %d",
+                len(self._invalid_block_ids),
+            )
+        super().bind_connector_metadata(connector_metadata)
+
+    def clear_connector_metadata(self) -> None:
+        self._invalid_block_ids = None
+        super().clear_connector_metadata()
+
+    def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None:
+        if self._async_load and forward_context.attn_metadata is None:
+            # Bypass sanity check in super().start_load_kv
+            forward_context.attn_metadata = "None"
+
+        super().start_load_kv(forward_context, **kwargs)
+
+    def get_finished(
+        self, finished_req_ids: set[str]
+    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+        if self._async_load:
+            meta = self._get_connector_metadata()
+            assert isinstance(meta, RogueSharedStorageConnectorMetadata)
+            if meta.req_to_block_ids:
+                return None, set(meta.req_to_block_ids)
+
+        return None, None
+
+    def get_block_ids_with_load_errors(self) -> set[int]:
+        return self._invalid_block_ids or set()
+
+    def get_num_new_matched_tokens(
+        self,
+        request: Request,
+        num_computed_tokens: int,
+    ) -> tuple[int, bool]:
+        if request.request_id in self._seen_requests:
+            return 0, False
+
+        self._seen_requests.add(request.request_id)
+
+        num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens)
+        return num_tokens, self._async_load and num_tokens > 0
+
+    def update_state_after_alloc(
+        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
+    ):
+        """
+        Update KVConnector state after block allocation.
+
+        If blocks were allocated, add to _requests_need_load,
+        such that we load the KVs in the next forward pass.
+        """
+        super().update_state_after_alloc(request, blocks, num_external_tokens)
+
+        if num_external_tokens > 0:
+            self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0]
+
+    def build_connector_meta(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> KVConnectorMetadata:
+        if not self._async_load:
+            base = super().build_connector_meta(scheduler_output)
+            meta = RogueSharedStorageConnectorMetadata.from_base(base)
+        else:
+            meta = RogueSharedStorageConnectorMetadata()
+            if self._requests_need_load:
+                for req_id, request in self._requests_need_load.items():
+                    meta.add_request(
+                        token_ids=request.prompt_token_ids,
+                        block_ids=self._req_to_block_ids[req_id],
+                        block_size=self._block_size,
+                        is_store=False,
+                        mm_hashes=[],
+                    )
+                # Clear state
+                self._requests_need_load.clear()
+        meta.req_to_block_ids = self._req_to_block_ids
+        self._req_to_block_ids = dict()
+        return meta
diff --git a/examples/offline_inference/kv_load_failure_recovery/run.sh b/examples/offline_inference/kv_load_failure_recovery/run.sh
new file mode 100755
index 000000000000..53fe2385d46d
--- /dev/null
+++ b/examples/offline_inference/kv_load_failure_recovery/run.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+# Constants
+SHARED_STORAGE_DIR="local_storage"
+PREFILL_OUTPUT="prefill_output.txt"
+DECODE_OUTPUT="decode_output.txt"
+SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt"
+ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt"
+
+# Cleanup
+rm -rf "$SHARED_STORAGE_DIR"
+rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
+
+# Run inference examples
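+# Note: VLLM_ENABLE_V1_MULTIPROCESSING=0 keeps the engine in the launching
+# process, so the connector module that sits next to these scripts can be
+# imported via kv_connector_module_path; CUDA_VISIBLE_DEVICES=0 pins every
+# stage to a single GPU.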
+VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
+VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
+VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure
+VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure --async-load
+
+# Compare outputs
+if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then
+    echo "❌ Outputs differ: sync recovery failed."
+    diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"
+    exit 1
+fi
+
+if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then
+    echo "❌ Outputs differ: async recovery failed."
+    diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
+    exit 1
+fi
+
+echo "✅ Outputs match: recovery successful."
diff --git a/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py
new file mode 100644
index 000000000000..549e85875025
--- /dev/null
+++ b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py
@@ -0,0 +1,341 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Callable
+from unittest.mock import Mock
+
+import pytest
+
+from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.request import Request, RequestStatus
+
+from .utils import (create_model_runner_output, create_request,
+                    create_scheduler, create_vllm_config)
+
+
+def _make_get_num_new_matched_tokens(
+    req_num_new_matched_tokens: dict[str, int],
+    async_load: bool,
+) -> Callable[[Request, int], tuple[int, bool]]:
+
+    def get_num_new_matched_tokens(request: Request,
+                                   _: int) -> tuple[int, bool]:
+        value = req_num_new_matched_tokens.get(request.request_id, 0)
+        return value, async_load
+
+    return get_num_new_matched_tokens
+
+
+@pytest.fixture
+def scheduler():
+    vllm_config = create_vllm_config()
+    return create_scheduler(vllm_config)
+
+
+@pytest.mark.parametrize(
+    "num_prompt_blocks,"
+    "num_external_computed_blocks,"
+    "invalid_block_idxs",
+    [
+        (100, 99, {0, 98}),
+        (100, 99, {50, 98}),
+        (100, 99, {98}),
+    ],
+)
+def test_async_load_failure(
+    scheduler: Scheduler,
+    num_prompt_blocks: int,
+    num_external_computed_blocks: int,
+    invalid_block_idxs: set[int],
+):
+    assert num_prompt_blocks >= num_external_computed_blocks
+
+    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
+    num_external_computed_tokens = (num_external_computed_blocks *
+                                    scheduler.block_size)
+
+    request1 = create_request(num_tokens=num_prompt_tokens)
+    scheduler.add_request(request=request1)
+    request2 = create_request(num_tokens=num_prompt_tokens)
+    scheduler.add_request(request=request2)
+    request3 = create_request(num_tokens=num_prompt_tokens)
+    scheduler.add_request(request=request3)
+
+    # Mock KV connector method.
+    # req_id -> num_external_computed_tokens
+    req_num_new_matched_tokens = {
+        request1.request_id: num_external_computed_tokens,
+        request2.request_id: num_external_computed_tokens,
+        request3.request_id: num_external_computed_tokens,
+    }
+
+    scheduler.connector = Mock()
+    scheduler.connector.get_num_new_matched_tokens.side_effect = (
+        _make_get_num_new_matched_tokens(req_num_new_matched_tokens,
+                                         async_load=True))
+    scheduler.connector.take_events.return_value = ()
+
+    scheduler_output = scheduler.schedule()
+
+    assert len(scheduler.waiting) == 3
+    for request in scheduler.waiting:
+        assert request.num_computed_tokens == 0
+        assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
+    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
+
+    # Simulate a failure in loading some of request2 blocks.
+    (req2_block_ids, ) = scheduler.kv_cache_manager.get_block_ids(
+        request2.request_id)
+    invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
+    model_runner_output = create_model_runner_output(
+        reqs=[],
+        finished_recving={request1.request_id, request3.request_id},
+        invalid_block_ids=invalid_block_ids,
+        use_eos=True)
+
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+
+    min_invalid_block_idx = min(invalid_block_idxs)
+
+    assert len(scheduler.waiting) == 3
+    for request in scheduler.waiting:
+        if request.request_id == request2.request_id:
+            assert request.num_computed_tokens == (min_invalid_block_idx *
+                                                   scheduler.block_size)
+        else:
+            assert request.num_computed_tokens == 0
+        assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
+    assert scheduler.failed_recving_kv_req_ids == {request2.request_id}
+    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
+
+
+@pytest.mark.parametrize(
+    "num_prompt_blocks,"
+    "num_external_computed_blocks,"
+    "invalid_block_idxs",
+    [
+        (100, 99, {0, 98}),
+        (100, 99, {50, 98}),
+        (100, 99, {98}),
+    ],
+)
+def test_sync_load_failure(
+    scheduler: Scheduler,
+    num_prompt_blocks: int,
+    num_external_computed_blocks: int,
+    invalid_block_idxs: set[int],
+):
+    assert num_prompt_blocks >= num_external_computed_blocks
+
+    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
+    num_external_computed_tokens = (num_external_computed_blocks *
+                                    scheduler.block_size)
+
+    request1 = create_request(num_tokens=num_prompt_tokens)
+    scheduler.add_request(request=request1)
+    request2 = create_request(num_tokens=num_prompt_tokens)
+    scheduler.add_request(request=request2)
+    request3 = create_request(num_tokens=num_prompt_tokens)
+    scheduler.add_request(request=request3)
+
+    # Mock KV connector method.
+    # req_id -> num_external_computed_tokens
+    req_num_new_matched_tokens = {
+        request1.request_id: num_external_computed_tokens,
+        request2.request_id: num_external_computed_tokens,
+        request3.request_id: num_external_computed_tokens,
+    }
+
+    scheduler.connector = Mock()
+    scheduler.connector.get_num_new_matched_tokens.side_effect = (
+        _make_get_num_new_matched_tokens(req_num_new_matched_tokens,
+                                         async_load=False))
+    scheduler.connector.request_finished.return_value = (False, None)
+    scheduler.connector.take_events.return_value = ()
+
+    scheduler_output = scheduler.schedule()
+
+    # req_id -> num_computed_tokens
+    expected_computed_tokens = {
+        request1.request_id: num_external_computed_tokens,
+        request2.request_id: num_external_computed_tokens,
+        request3.request_id: num_external_computed_tokens,
+    }
+
+    assert len(scheduler.running) == 3
+    assert len(scheduler_output.scheduled_new_reqs) == 3
+    for request in scheduler_output.scheduled_new_reqs:
+        assert request.num_computed_tokens == expected_computed_tokens[
+            request.req_id]
+    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
+
+    # Simulate a failure in loading some of request2 blocks.
+    req2_block_ids = scheduler_output.scheduled_new_reqs[1].block_ids[0]
+    invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
+    model_runner_output = create_model_runner_output(
+        [request1, request2, request3],
+        invalid_block_ids=invalid_block_ids,
+        use_eos=True)
+
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0].request_id == request2.request_id
+    assert scheduler.running[0].num_computed_tokens == (
+        min(invalid_block_idxs) * scheduler.block_size)
+    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
+    assert scheduler.connector.request_finished.call_count == 2
+
+
+@pytest.mark.parametrize(
+    "num_prompt_blocks,"
+    "num_external_computed_blocks,"
+    "num_common_prefix_blocks,"
+    "invalid_block_idxs",
+    [
+        (100, 99, 50, {0, 49}),
+        (100, 99, 50, {25, 49}),
+        (100, 99, 50, {49}),
+    ],
+)
+def test_sync_load_failure_with_shared_blocks(
+    scheduler: Scheduler,
+    num_prompt_blocks: int,
+    num_external_computed_blocks: int,
+    num_common_prefix_blocks: int,
+    invalid_block_idxs: set[int],
+):
+    assert (num_prompt_blocks >= num_external_computed_blocks >=
+            num_common_prefix_blocks)
+
+    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
+    num_external_computed_tokens = (num_external_computed_blocks *
+                                    scheduler.block_size)
+    common_prefix_len = num_common_prefix_blocks * scheduler.block_size
+
+    request1 = create_request(num_tokens=num_prompt_tokens,
+                              common_prefix_len=common_prefix_len)
+    scheduler.add_request(request=request1)
+    request2 = create_request(num_tokens=num_prompt_tokens,
+                              common_prefix_len=common_prefix_len)
+    scheduler.add_request(request=request2)
+
+    # Mock KV connector method.
+    # req_id -> num_external_computed_tokens
+    req_num_new_matched_tokens = {
+        request1.request_id: num_external_computed_tokens,
+    }
+
+    scheduler.connector = Mock()
+    scheduler.connector.get_num_new_matched_tokens.side_effect = (
+        _make_get_num_new_matched_tokens(req_num_new_matched_tokens,
+                                         async_load=False))
+    scheduler.connector.take_events.return_value = ()
+
+    scheduler_output = scheduler.schedule()
+
+    # req_id -> num_computed_tokens
+    expected_computed_tokens = {
+        request1.request_id: num_external_computed_tokens,
+        request2.request_id: common_prefix_len,
+    }
+
+    assert len(scheduler.running) == 2
+    assert len(scheduler_output.scheduled_new_reqs) == 2
+    for request in scheduler_output.scheduled_new_reqs:
+        assert request.num_computed_tokens == expected_computed_tokens[
+            request.req_id]
+    assert scheduler.connector.get_num_new_matched_tokens.call_count == 2
+
+    # Simulate a failure in loading some of the shared blocks.
+    req1_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
+    invalid_block_ids = {req1_block_ids[i] for i in invalid_block_idxs}
+    model_runner_output = create_model_runner_output(
+        [request1, request2],
+        invalid_block_ids=invalid_block_ids,
+        use_eos=True)
+
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+
+    # req_id -> num_computed_tokens
+    # all the common prefix blocks will be computed by request1
+    expected_computed_tokens = {
+        request1.request_id: min(invalid_block_idxs) * scheduler.block_size,
+        request2.request_id: common_prefix_len,
+    }
+
+    assert len(scheduler.running) == 2
+    for request in scheduler.running:
+        assert request.num_computed_tokens == expected_computed_tokens[
+            request.request_id]
+    assert scheduler.connector.get_num_new_matched_tokens.call_count == 2
+
+
+@pytest.mark.parametrize(
+    "num_prompt_blocks,"
+    "num_external_computed_blocks,"
+    "invalid_block_idxs",
+    [
+        (100, 99, {0, 50, 98}),
+        (100, 99, {98, 50, 0}),
+    ],
+)
+def test_async_progressive_load_failure(
+    scheduler: Scheduler,
+    num_prompt_blocks: int,
+    num_external_computed_blocks: int,
+    invalid_block_idxs: set[int],
+):
+    assert num_prompt_blocks >= num_external_computed_blocks
+
+    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
+    num_external_computed_tokens = (num_external_computed_blocks *
+                                    scheduler.block_size)
+
+    request = create_request(num_tokens=num_prompt_tokens)
+    scheduler.add_request(request=request)
+
+    # Mock KV connector method.
+    # req_id -> num_external_computed_tokens
+    req_num_new_matched_tokens = {
+        request.request_id: num_external_computed_tokens,
+    }
+
+    scheduler.connector = Mock()
+    scheduler.connector.get_num_new_matched_tokens.side_effect = (
+        _make_get_num_new_matched_tokens(req_num_new_matched_tokens,
+                                         async_load=True))
+    scheduler.connector.take_events.return_value = ()
+
+    scheduler_output = scheduler.schedule()
+
+    assert len(scheduler.waiting) == 1
+    assert scheduler.waiting.peek_request().request_id == request.request_id
+    assert request.num_computed_tokens == 0
+    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
+    assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
+
+    # Running minimum of failed block indices; start one past the highest
+    # candidate so the first min() below picks up the first failure.
+    min_invalid_block_idx = max(invalid_block_idxs) + 1
+    # Simulate failures when progressively loading request blocks.
+    for invalid_block_idx in invalid_block_idxs:
+        (req_block_ids, ) = scheduler.kv_cache_manager.get_block_ids(
+            request.request_id)
+        invalid_block_ids = {req_block_ids[invalid_block_idx]}
+        model_runner_output = create_model_runner_output(
+            reqs=[],
+            finished_recving=set(),
+            invalid_block_ids=invalid_block_ids,
+            use_eos=True)
+
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+
+        min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)
+
+        assert len(scheduler.waiting) == 1
+        assert scheduler.waiting.peek_request(
+        ).request_id == request.request_id
+        assert request.num_computed_tokens == (min_invalid_block_idx *
+                                               scheduler.block_size)
+        assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
+        assert scheduler.failed_recving_kv_req_ids == {request.request_id}
+        assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py
index f9a4d2fb4de4..f728b25d7834 100644
--- a/tests/v1/kv_connector/unit/test_offloading_connector.py
+++ b/tests/v1/kv_connector/unit/test_offloading_connector.py
@@ -281,8 +281,8 @@ def _run(self, decoded_tokens: list[int]):
 
         model_runner_output = create_model_runner_output(
             reqs=self.scheduler.running,
-            finished_sending=list(finished_sending),
-            finished_recving=list(finished_recving),
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
             token_id=token_id)
 
         if self.scheduler.running:
diff --git a/tests/v1/kv_connector/unit/test_output_aggreagator.py b/tests/v1/kv_connector/unit/test_output_aggreagator.py
index 5d2b27a9eb4d..665c55d72fa2 100644
--- a/tests/v1/kv_connector/unit/test_output_aggreagator.py
+++ b/tests/v1/kv_connector/unit/test_output_aggreagator.py
@@ -11,26 +11,38 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
 
     def __init__(self,
                  finished_sending: Optional[set[str]] = None,
-                 finished_recving: Optional[set[str]] = None):
+                 finished_recving: Optional[set[str]] = None,
+                 invalid_block_ids: Optional[set[int]] = None):
         self.kv_connector_output = KVConnectorOutput(
             finished_sending=finished_sending,
             finished_recving=finished_recving,
-        )
+            invalid_block_ids=invalid_block_ids or set())
 
     def __repr__(self):
         return (
            f"DummyModelRunnerOutput("
            f"finished_sending={self.kv_connector_output.finished_sending},"
-           f"finished_recving={self.kv_connector_output.finished_recving})")
+           f"finished_recving={self.kv_connector_output.finished_recving},"
+           f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})")
 
 
 def test_aggregate_workers_output():
     aggregator = KVOutputAggregator(world_size=2)
 
+    output1 = DummyModelRunnerOutput()
+    output2 = DummyModelRunnerOutput()
+
+    aggregated = aggregator.aggregate([output1, output2])
+
+    assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
+    assert aggregated.finished_sending is None
+    assert aggregated.finished_recving is None
+    assert not aggregated.invalid_block_ids
+
     output1 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving={'req2'})
-    output2 = DummyModelRunnerOutput(finished_sending=None,
-                                     finished_recving=None)
+    output2 = DummyModelRunnerOutput(invalid_block_ids={1})
 
     aggregated = aggregator.aggregate([output1, output2])
 
@@ -38,11 +50,10 @@ def test_aggregate_workers_output():
     aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving is None
+    assert aggregated.invalid_block_ids == {1}
 
-    output1 = DummyModelRunnerOutput(finished_sending=None,
-                                     finished_recving=None)
-    output2 = DummyModelRunnerOutput(finished_sending={'req1'},
-                                     finished_recving=None)
+    output1 = DummyModelRunnerOutput(invalid_block_ids={2})
+    output2 = DummyModelRunnerOutput(finished_sending={'req1'})
 
     aggregated = aggregator.aggregate([output1, output2])
 
@@ -50,11 +61,11 @@ def test_aggregate_workers_output():
     aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending == {'req1'}
     assert aggregated.finished_recving is None
+    assert aggregated.invalid_block_ids == {2}
 
-    output1 = DummyModelRunnerOutput(finished_sending=None,
-                                     finished_recving=None)
-    output2 = DummyModelRunnerOutput(finished_sending={'req1'},
-                                     finished_recving={'req2'})
+    output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
+    output2 = DummyModelRunnerOutput(finished_recving={'req2'},
+                                     invalid_block_ids={4, 5})
 
     aggregated = aggregator.aggregate([output1, output2])
 
@@ -62,6 +73,7 @@ def test_aggregate_workers_output():
     aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}
+    assert aggregated.invalid_block_ids == {3, 4, 5}
 
 
 def test_async_aggregate_workers_output():
@@ -71,10 +83,26 @@ def test_async_aggregate_workers_output():
     future2: Future[DummyModelRunnerOutput] = Future()
     result_future = aggregator.async_aggregate([future1, future2])
 
+    output1 = DummyModelRunnerOutput()
+    output2 = DummyModelRunnerOutput()
+    future1.set_result(output1)
+    future2.set_result(output2)
+
+    assert result_future.done()
+    aggregated = result_future.result()
+    assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
+    assert aggregated.finished_sending is None
+    assert aggregated.finished_recving is None
+    assert not aggregated.invalid_block_ids
+
+    future1 = Future()
+    future2 = Future()
+    result_future = aggregator.async_aggregate([future1, future2])
+
     output1 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving={'req2'})
-    output2 = DummyModelRunnerOutput(finished_sending=None,
-                                     finished_recving=None)
+    output2 = DummyModelRunnerOutput(invalid_block_ids={1})
 
     future1.set_result(output1)
     future2.set_result(output2)
@@ -84,15 +112,14 @@ def test_async_aggregate_workers_output():
     aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving is None
+    assert aggregated.invalid_block_ids == {1}
 
     future1 = Future()
     future2 = Future()
     result_future = aggregator.async_aggregate([future1, future2])
 
-    output1 = DummyModelRunnerOutput(finished_sending=None,
-                                     finished_recving=None)
-    output2 = DummyModelRunnerOutput(finished_sending={'req1'},
-                                     finished_recving=None)
+    output1 = DummyModelRunnerOutput(invalid_block_ids={2})
+    output2 = DummyModelRunnerOutput(finished_sending={'req1'})
 
     future1.set_result(output1)
     future2.set_result(output2)
@@ -102,15 +129,15 @@ def test_async_aggregate_workers_output():
     aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending == {'req1'}
     assert aggregated.finished_recving is None
+    assert aggregated.invalid_block_ids == {2}
 
     future1 = Future()
     future2 = Future()
     result_future = aggregator.async_aggregate([future1, future2])
 
-    output1 = DummyModelRunnerOutput(finished_sending=None,
-                                     finished_recving=None)
-    output2 = DummyModelRunnerOutput(finished_sending={'req1'},
-                                     finished_recving={'req2'})
+    output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
+    output2 = DummyModelRunnerOutput(finished_recving={'req2'},
+                                     invalid_block_ids={4, 5})
 
     future1.set_result(output1)
     future2.set_result(output2)
@@ -120,3 +147,4 @@ def test_async_aggregate_workers_output():
     aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}
+    assert aggregated.invalid_block_ids == {3, 4, 5}
diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
index 380e72a15633..2d0cd35e0606 100644
--- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
@@ -88,7 +88,7 @@ def test_basic_lifecycle():
     # (3b): execute_model()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
     model_runner_output.kv_connector_output = KVConnectorOutput(
-        finished_sending=[request_id])
+        finished_sending={request_id})
 
     # (3c): update_from_output()
     scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -135,7 +135,7 @@ def test_short_prompt_lifecycle():
     scheduler_output = scheduler.schedule()
     # Use create_model_runner_output to pass kv_connector_output along
     model_runner_output = create_model_runner_output(
-        reqs=[request], finished_sending=[request.request_id])
+        reqs=[request], finished_sending={request.request_id})
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert_scheduler_empty(scheduler)
 
@@ -191,6 +191,6 @@ def test_prefix_cache_lifecycle():
     scheduler_output = scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
     model_runner_output.kv_connector_output = KVConnectorOutput(
-        finished_sending=[request_remote.request_id])
+        finished_sending={request_remote.request_id})
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert_scheduler_empty(scheduler)
diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
index 21fec5344255..2f5cb6fd9112 100644
--- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
@@ -74,7 +74,7 @@ def test_basic_lifecycle():
     # (2b): forward(): request finishes recv.
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
     model_runner_output.kv_connector_output = KVConnectorOutput(
-        finished_recving=[request_id])
+        finished_recving={request_id})
 
     # (2c): update_from_output():
     engine_core_outputs = scheduler.update_from_output(scheduler_output,
@@ -193,7 +193,7 @@ def test_interleaved_lifecycle():
 
     model_runner_output = create_model_runner_output(
         [request_local_a, request_local_b],
-        finished_recving=[request_remote.request_id])
+        finished_recving={request_remote.request_id})
     scheduler.update_from_output(scheduler_output, model_runner_output)
 
     # STEP 5: RECVed KVs are sent to ModelRunner.
@@ -242,16 +242,16 @@ def test_no_spurious_prefix_caching():
         request_id=1,
         block_size=BLOCK_SIZE,
         num_tokens=NUM_TOKENS,
+        common_prefix_len=NUM_TOKENS,
         do_remote_prefill=True,
-        use_all_1s_for_prompt_tokens=True,
     )
 
     request_local = create_request(
         request_id=2,
         block_size=BLOCK_SIZE,
         num_tokens=NUM_TOKENS,
+        common_prefix_len=NUM_TOKENS,
        do_remote_prefill=False,
-        use_all_1s_for_prompt_tokens=True,
     )
 
     # Schedule the remote prefill request. This should not
@@ -318,7 +318,7 @@ def test_full_block_prompt():
     scheduler_output = scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
     model_runner_output.kv_connector_output = KVConnectorOutput(
-        finished_recving=[request_id])
+        finished_recving={request_id})
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.waiting) == 1
     assert (request_id in scheduler.finished_recving_kv_req_ids)
@@ -398,7 +398,7 @@ def test_cannot_schedule_after_recv():
     # Step 3: finish recving (5 blocks in use)
     scheduler_output = scheduler.schedule()
     model_runner_output = create_model_runner_output(
-        reqs=[request_normal], finished_recving=[request_remote.request_id])
+        reqs=[request_normal], finished_recving={request_remote.request_id})
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.running) == 1
     assert len(scheduler.waiting) == 1
@@ -512,7 +512,7 @@ def test_cannot_recv():
     # Step 5: finish recving (5 blocks in use)
     scheduler_output = scheduler.schedule()
     model_runner_output = create_model_runner_output(
-        reqs=[], finished_recving=[request_remote.request_id])
+        reqs=[], finished_recving={request_remote.request_id})
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.running) == 0
     assert len(scheduler.waiting) == 1
diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py
index de52668e3dcf..3928cdc37b9d 100644
--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import tempfile
 from collections import defaultdict
+from itertools import count
 from typing import Any, Callable, Optional
 
 import torch
@@ -61,12 +62,15 @@ def create_vllm_config(
     max_num_seqs: int = 16,
     max_num_batched_tokens: int = 64,
     block_size: int = 16,
+    max_model_len: int = 10000,
+    enable_chunked_prefill: bool = True,
 ) -> VllmConfig:
     """Initialize VllmConfig For Testing."""
     scheduler_config = SchedulerConfig(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
-        max_model_len=max_num_batched_tokens,
+        max_model_len=max_model_len,
+        enable_chunked_prefill=enable_chunked_prefill,
     )
     model_config = ModelConfig(
         model=model,
@@ -117,19 +121,27 @@ def create_scheduler(
     )
 
 
+_request_count = count(1)
 _none_hash_initialized = False
 
 
-def create_request(request_id: int,
-                   num_tokens: int = 10,
-                   max_tokens: int = 16,
-                   do_remote_decode: bool = False,
-                   do_remote_prefill: bool = False,
-                   use_all_1s_for_prompt_tokens: bool = False,
-                   num_remote_blocks: int = 3,
-                   block_size: int = 16,
-                   hash_fn: Callable = sha256) -> Request:
+def create_request(
+    request_id: Optional[int] = None,
+    num_tokens: int = 10,
+    common_prefix_len: int = 0,
+    max_tokens: int = 16,
+    do_remote_decode: bool = False,
+    do_remote_prefill: bool = False,
+    num_remote_blocks: int = 3,
+    block_size: int = 16,
+    hash_fn: Callable = sha256,
+) -> Request:
     """Make dummy request for testing."""
+    assert num_tokens >= common_prefix_len >= 0
+
+    if request_id is None:
+        request_id = next(_request_count)
+
     global _none_hash_initialized
     if not _none_hash_initialized:
         init_none_hash(hash_fn)
@@ -153,10 +165,9 @@ def create_request(
     max_tokens = 1 if do_remote_decode else max_tokens
     sampling_params = SamplingParams(max_tokens=max_tokens)
 
-    if use_all_1s_for_prompt_tokens:
-        prompt_token_ids = [1] * num_tokens
-    else:
-        prompt_token_ids = [i * request_id for i in range(num_tokens)]
+    common_prefix = [1] * common_prefix_len if common_prefix_len > 0 else []
+    suffix = [i * request_id for i in range(num_tokens - common_prefix_len)]
+    prompt_token_ids = common_prefix + suffix
 
     req = Request(
         request_id=f"id-{request_id}",
@@ -173,8 +184,9 @@ def create_request(
 
 def create_model_runner_output(
     reqs: list[Request],
-    finished_sending: Optional[list[str]] = None,
-    finished_recving: Optional[list[str]] = None,
+    finished_sending: Optional[set[str]] = None,
+    finished_recving: Optional[set[str]] = None,
+    invalid_block_ids: Optional[set[int]] = None,
     use_eos: bool = False,
     token_id: int = 0,
 ) -> ModelRunnerOutput:
@@ -189,10 +201,11 @@ def create_model_runner_output(
     sampled_token_ids = [[sampled_token] for _ in req_ids]
 
     kv_connector_output = None if (
-        finished_sending is None
-        and finished_recving is None) else KVConnectorOutput(
+        finished_sending is None and finished_recving is None
+        and invalid_block_ids is None) else KVConnectorOutput(
             finished_sending=finished_sending,
            finished_recving=finished_recving,
+            invalid_block_ids=invalid_block_ids or set(),
         )
 
     # Make output data structure.
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 49a7a61e1889..23d7ce4cefa3 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -250,6 +250,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
         new_token_ids=[[]],
         new_block_ids=([[0]], ),
         num_computed_tokens=[0],
+        num_output_tokens=[0],
     )
 
     scheduler_output = SchedulerOutput(
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index efa4c9abf47f..103fba41fcb4 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -117,7 +117,7 @@ def get_kv_connector_cache_layout():
 
 
 class KVOutputAggregator:
-    """Utility class to aggregate the output of all workers into a single 
+    """Utility class to aggregate the output of all workers into a single
     output corresponding to Rank 0 for scheduler."""
 
     def __init__(self, world_size: int):
@@ -143,6 +143,7 @@ def update_finished_set(req_ids: Optional[set[str]],
         finished_sending = set[str]()
         finished_recving = set[str]()
         aggregated_kv_connector_stats = None
+        invalid_block_ids = set[int]()
         for model_runner_output in outputs:
             output = model_runner_output.kv_connector_output
             if not output:
@@ -165,6 +166,8 @@ def update_finished_set(req_ids: Optional[set[str]],
                     aggregated_kv_connector_stats = \
                         aggregated_kv_connector_stats.aggregate(kv_connector_stats)
 
+            invalid_block_ids |= output.invalid_block_ids
+
         # select output of the worker specified by output_rank
         output = outputs[output_rank]
 
@@ -172,6 +175,7 @@ def update_finished_set(req_ids: Optional[set[str]],
             finished_sending=finished_sending or None,
             finished_recving=finished_recving or None,
             kv_connector_stats=aggregated_kv_connector_stats or None,
+            invalid_block_ids=invalid_block_ids,
         )
 
         return output
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index 184d0a62f2c3..e3b4bcbfd1e6 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -229,6 +229,26 @@ def get_finished(
         """
         return None, None
 
+    def get_block_ids_with_load_errors(self) -> set[int]:
+        """
+        Get the set of block IDs that failed to load.
+
+        Returns:
+            Set of block IDs that encountered load errors.
+            Empty set if no load errors occurred.
+
+        Notes:
+            - Applies to both sync- and async-loading requests.
+            - Async loading: failed blocks may be reported in any forward pass
+              up to and including the pass where the request ID is returned by
+              `get_finished()`. Even if failures occur, the request must still
+              be reported via `get_finished()`, and the failed block IDs must
+              appear here no later than that same pass.
+            - Sync loading: failed blocks should be reported in the forward
+              pass in which they are detected.
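+
+        Example:
+            A connector that tracks failures while loading might implement
+            this as follows (illustrative sketch only; `_failed_block_ids`
+            is a hypothetical attribute):
+
+            .. code-block:: python
+
+                def get_block_ids_with_load_errors(self) -> set[int]:
+                    failed = self._failed_block_ids
+                    self._failed_block_ids = set()
+                    return failed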
+        """
+        return set()
+
     def shutdown(self):
         """
         Shutdown the connector. This is called when the worker process
@@ -264,14 +284,21 @@ def get_num_new_matched_tokens(
 
         Returns:
             A tuple with the following elements:
-                - An optional number of tokens that can be loaded from the 
-                  external KV cache beyond what is already computed. 
+                - An optional number of tokens that can be loaded from the
+                  external KV cache beyond what is already computed.
                   If None, it means that the connector needs more time to
                   determine the number of matched tokens, and the scheduler
                   should query for this request again later.
                 - `True` if external KV cache tokens will be loaded
                   asynchronously (between scheduler steps). Must be
                   'False' if the first element is 0.
+
+        Notes:
+            The connector should only consider the largest prefix of prompt
+            tokens for which KV cache is actually available at the time of the
+            call. If the cache cannot be loaded for some tokens (e.g., due to
+            connectivity issues or eviction), those tokens must not be taken
+            into account.
         """
         pass
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
index 6836a71e58d6..a7713ba326fc 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -189,6 +189,12 @@ def get_finished(
 
         return finished_sending or None, finished_recving or None
 
+    def get_block_ids_with_load_errors(self) -> set[int]:
+        agg_block_ids: set[int] = set()
+        for c in self._connectors:
+            agg_block_ids |= c.get_block_ids_with_load_errors()
+        return agg_block_ids
+
     # ==============================
     # Scheduler-side methods
     # ==============================
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
index 48fa1a82c677..c9949d81465c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import hashlib
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Optional
 
 import safetensors
@@ -55,10 +55,7 @@ def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
 
 @dataclass
 class SharedStorageConnectorMetadata(KVConnectorMetadata):
-    requests: list[ReqMeta]
-
-    def __init__(self):
-        self.requests = []
+    requests: list[ReqMeta] = field(default_factory=list)
 
     def add_request(
         self,
diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
index 3cc738304821..617a724a1ad2 100644
--- a/vllm/v1/core/block_pool.py
+++ b/vllm/v1/core/block_pool.py
@@ -211,7 +211,7 @@ def cache_full_blocks(
             block_size: Number of tokens in each block.
             kv_cache_group_id: The id of the KV cache group.
         """
-        if num_cached_blocks == num_full_blocks:
+        if num_cached_blocks >= num_full_blocks:
             return
         new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
         assert len(request.block_hashes) >= num_full_blocks
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 209fc2a4404f..6874e713aff3 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -101,6 +101,7 @@ class CachedRequestData:
     new_token_ids: list[list[int]]
     new_block_ids: list[Optional[tuple[list[int], ...]]]
     num_computed_tokens: list[int]
+    num_output_tokens: list[int]
 
     @property
     def num_reqs(self) -> int:
@@ -114,6 +115,7 @@ def make_empty(cls) -> CachedRequestData:
             new_token_ids=[],
             new_block_ids=[],
             num_computed_tokens=[],
+            num_output_tokens=[],
         )
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 8d6ea887142d..d4be1b06b3b2 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -133,6 +133,7 @@ def __init__(
 
         # KV Connector: requests in process of async KV loading or recving
         self.finished_recving_kv_req_ids: set[str] = set()
+        self.failed_recving_kv_req_ids: set[str] = set()
 
         # Encoder-related.
         # Calculate encoder cache size if applicable
@@ -671,6 +672,7 @@ def _make_cached_request_data(
         new_token_ids: list[list[int]] = []
         new_block_ids: list[Optional[tuple[list[int], ...]]] = []
         num_computed_tokens: list[int] = []
+        num_output_tokens: list[int] = []
 
         use_connector = self.connector is not None
         for req in itertools.chain(running_reqs, resumed_reqs):
@@ -695,6 +697,7 @@ def _make_cached_request_data(
             new_block_ids.append(
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True))
             num_computed_tokens.append(req.num_computed_tokens)
+            num_output_tokens.append(len(req.output_token_ids))
         # Because resumed_reqs is usually empty, it is more efficient to do
         # in-place appending so that we don't need to allocate a new list.
         resumed_from_preemption = [False] * len(running_reqs)
@@ -706,6 +709,7 @@ def _make_cached_request_data(
             new_token_ids=new_token_ids,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
+            num_output_tokens=num_output_tokens,
         )
 
     def _try_schedule_encoder_inputs(
@@ -878,6 +882,14 @@ def update_from_output(
         kv_connector_stats = (kv_connector_output.kv_connector_stats
                               if kv_connector_output else None)
 
+        failed_kv_load_req_ids = None
+        if kv_connector_output and kv_connector_output.invalid_block_ids:
+            # These blocks contain externally computed tokens that failed to
+            # load. Identify affected requests and adjust their computed token
+            # count to trigger recomputation of the invalid blocks.
+            failed_kv_load_req_ids = self._handle_invalid_blocks(
+                kv_connector_output.invalid_block_ids)
+
         # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
         # the below loop can be a performance bottleneck. We should do our best
         # to avoid expensive operations inside the loop.
@@ -885,6 +897,9 @@ def update_from_output(
         stopped_preempted_reqs: set[Request] = set()
         for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
             assert num_tokens_scheduled > 0
+            if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
+                # Skip requests that were recovered from KV load failure
+                continue
             request = self.requests.get(req_id)
             if request is None:
                 # The request is already finished. This can happen if the
@@ -988,9 +1003,8 @@ def update_from_output(
             self.waiting.remove_requests(stopped_preempted_reqs)
 
         # KV Connector: update state for finished KV Transfers.
-        if model_runner_output.kv_connector_output:
-            self._update_from_kv_xfer_finished(
-                model_runner_output.kv_connector_output)
+        if kv_connector_output:
+            self._update_from_kv_xfer_finished(kv_connector_output)
 
         # Create EngineCoreOutputs for all clients that have requests with
         # outputs in this step.
@@ -1252,18 +1266,33 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
         if request.request_id not in self.finished_recving_kv_req_ids:
             return False
 
-        # Now that the blocks are ready, actually cache them.
-        (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
-        num_computed_tokens = len(block_ids) * self.block_size
-        # Handle the case where num request tokens less than one block.
-        num_computed_tokens = min(num_computed_tokens, request.num_tokens)
-        if num_computed_tokens == request.num_tokens:
-            num_computed_tokens -= 1
-        # This will cache the blocks iff caching is enabled.
-        self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
+        if request.request_id in self.failed_recving_kv_req_ids:
+            # Request had KV load failures; num_computed_tokens was already
+            # updated in _update_requests_with_invalid_blocks
+            if request.num_computed_tokens:
+                # Cache any valid computed tokens.
+                self.kv_cache_manager.cache_blocks(request,
+                                                   request.num_computed_tokens)
+            else:
+                # No valid computed tokens, release allocated blocks.
+                # There may be a local cache hit on retry.
+                self.kv_cache_manager.free(request)
 
-        # Update the request state for scheduling.
-        request.num_computed_tokens = num_computed_tokens
+            self.failed_recving_kv_req_ids.remove(request.request_id)
+        else:
+            # Now that the blocks are ready, actually cache them.
+            (block_ids, ) = self.kv_cache_manager.get_block_ids(
+                request.request_id)
+            num_computed_tokens = len(block_ids) * self.block_size
+            # Handle the case where num request tokens less than one block.
+            num_computed_tokens = min(num_computed_tokens, request.num_tokens)
+            if num_computed_tokens == request.num_tokens:
+                num_computed_tokens -= 1
+            # This will cache the blocks iff caching is enabled.
+            self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
+
+            # Update the request state for scheduling.
+            request.num_computed_tokens = num_computed_tokens
 
         # Return that we are ready.
         self.finished_recving_kv_req_ids.remove(request.request_id)
@@ -1296,3 +1325,134 @@ def _update_from_kv_xfer_finished(self,
                     "but the request is already freed.", req_id)
             else:
                 self._free_blocks(self.requests[req_id])
+
+    def _update_requests_with_invalid_blocks(
+            self, requests: Iterable[Request],
+            invalid_block_ids: set[int]) -> tuple[set[str], int]:
+        """
+        Identify and update requests affected by invalid KV cache blocks.
+
+        This method scans the given requests, detects those with invalid blocks
+        and adjusts their `num_computed_tokens` to the longest valid prefix.
+        For observability, it also accumulates the total number of tokens that
+        will need to be recomputed across all affected requests.
+
+        Args:
+            requests: The set of requests to scan for invalid blocks.
+            invalid_block_ids: IDs of invalid blocks.
+
+        Returns:
+            tuple:
+                - affected_req_ids (set[str]): IDs of requests impacted by
+                  invalid blocks.
+                - total_affected_tokens (int): Total number of tokens that must
+                  be recomputed across all affected requests (for observability).
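+
+        Example:
+            With block_size = 16, a request whose computed blocks are
+            [b0, b1, b2, b3] and whose lowest invalid block is b2 has its
+            num_computed_tokens truncated to 2 * 16 = 32, so b2 and b3 are
+            recomputed on reschedule while b0 and b1 are reused.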
+ """ + affected_req_ids: set[str] = set() + total_affected_tokens = 0 + # If a block is invalid and shared by multiple requests in the batch, + # these requests must be rescheduled, but only the first will recompute + # it. This set tracks blocks already marked for recomputation. + marked_invalid_block_ids: set[int] = set() + for request in requests: + is_affected = False + marked_invalid_block = False + req_id = request.request_id + # TODO (davidb): add support for hybrid memory allocator + (req_block_ids, ) = self.kv_cache_manager.get_block_ids(req_id) + # We iterate only over blocks that may contain externally computed + # tokens + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + # Async loading. If num_computed_tokens is set it implies we + # already processed some block failures for it in a prior step + req_num_computed_tokens = ( + request.num_computed_tokens if req_id + in self.failed_recving_kv_req_ids else len(req_block_ids) * + self.block_size) + else: + # Sync loading. num_computed_tokens includes new tokens + req_num_computed_tokens = request.num_cached_tokens + + req_num_computed_blocks = (req_num_computed_tokens + + self.block_size - 1) // self.block_size + for idx, block_id in zip(range(req_num_computed_blocks), + req_block_ids): + + if block_id not in invalid_block_ids: + continue + + is_affected = True + + if block_id in marked_invalid_block_ids: + # This invalid block is shared with a previous request + # and was already marked for recomputation. + # This means this request can still consider this block + # as computed when rescheduled. + # Currently this only applies to sync loading; Async + # loading does not yet support block sharing + continue + + marked_invalid_block_ids.add(block_id) + + if marked_invalid_block: + # This request has already marked an invalid block for + # recomputation and updated its num_computed_tokens. + continue + + marked_invalid_block = True + # Truncate the computed tokens at the first failed block + request.num_computed_tokens = idx * self.block_size + total_affected_tokens += (req_num_computed_tokens - + request.num_computed_tokens) + + if is_affected: + if not marked_invalid_block: + # All invalid blocks of this request are shared with + # previous requests and will be recomputed by them. + # Revert to considering only cached tokens as computed. 
+ # Currently this only applies to sync loading; Async + # loading does not yet support block sharing + total_affected_tokens += (request.num_computed_tokens - + request.num_cached_tokens) + request.num_computed_tokens = request.num_cached_tokens + + affected_req_ids.add(request.request_id) + + return (affected_req_ids, total_affected_tokens) + + def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: + total_requests_to_reschedule = 0 + total_tokens_to_reschedule = 0 + + # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) --- + async_load_reqs = ( + req for req in self.waiting + if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + async_affected_req_ids, num_tokens_to_reschedule = ( + self._update_requests_with_invalid_blocks(async_load_reqs, + invalid_block_ids)) + + total_requests_to_reschedule += len(async_affected_req_ids) + total_tokens_to_reschedule += num_tokens_to_reschedule + + # Mark requests with async KV load failures; they will be rescheduled + # once loading completes + self.failed_recving_kv_req_ids |= async_affected_req_ids + + # --- Handle sync KV loads (running requests) --- + sync_affected_req_ids, num_tokens_to_reschedule = ( + self._update_requests_with_invalid_blocks(self.running, + invalid_block_ids)) + + total_requests_to_reschedule += len(sync_affected_req_ids) + total_tokens_to_reschedule += num_tokens_to_reschedule + + if total_requests_to_reschedule: + logger.warning( + "Recovered from KV load failure: " + "%d request(s) rescheduled (%d tokens affected).", + total_requests_to_reschedule, total_tokens_to_reschedule) + + # Return the IDs of affected running requests to skip in + # update_from_output. + return sync_affected_req_ids diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index e889f7804e84..4ecd9c8157e2 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -142,6 +142,9 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None: num_cached_blocks = self.num_cached_block[request.request_id] num_full_blocks = num_tokens // self.block_size + if num_cached_blocks >= num_full_blocks: + return + self.block_pool.cache_full_blocks( request=request, blocks=self.req_to_blocks[request.request_id], diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 01f3676abd92..d15cdf365962 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, NamedTuple, Optional, Union import torch @@ -87,10 +87,13 @@ class KVConnectorOutput: finished_sending: Optional[set[str]] = None finished_recving: Optional[set[str]] = None kv_connector_stats: Optional["KVConnectorStats"] = None + # IDs of externally computed KV blocks that failed to load. + # Requests referencing these blocks should be rescheduled to recompute them. + invalid_block_ids: set[int] = field(default_factory=set) def is_empty(self): return (not self.finished_sending and not self.finished_recving - and not self.kv_connector_stats) + and not self.kv_connector_stats and not self.invalid_block_ids) # ModelRunnerOutput is serialized and sent to the scheduler process. 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bb5c3ea74293..947a7ed640d7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -634,8 +634,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_data.resumed_from_preemption[i] + num_output_tokens = req_data.num_output_tokens[i] # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens if not is_last_rank: @@ -653,6 +655,21 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: elif num_new_tokens > 0: req_state.output_token_ids.extend( new_token_ids[-num_new_tokens:]) + elif num_output_tokens < len(req_state.output_token_ids): + # Some output tokens were discarded due to a sync-KV-load + # failure. Align the cached state. + del req_state.output_token_ids[num_output_tokens:] + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is not None: + old_end_idx = self.input_batch.num_tokens_no_spec[ + req_index] + end_idx = self.input_batch.num_prompt_tokens[ + req_index] + num_output_tokens + self.input_batch.num_tokens[req_index] = end_idx + self.input_batch.num_tokens_no_spec[req_index] = end_idx + self.input_batch.is_token_ids[req_index, + end_idx:old_end_idx] = False # Update the block IDs. if not resumed_from_preemption: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 8c75e8914857..a135a594ac6f 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -464,8 +464,7 @@ def execute_model( # In case of PP with kv transfer, we need to pass through the # kv_connector_output - if (not kv_connector_output.finished_sending - and not kv_connector_output.finished_recving): + if kv_connector_output.is_empty(): return EMPTY_MODEL_RUNNER_OUTPUT output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 7eaff924ecc1..cdc0d317fffb 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -75,8 +75,7 @@ def kv_connector_no_forward(scheduler_output: "SchedulerOutput", scheduler_output, wait_for_save=False) as kv_connector_output: pass - if (not kv_connector_output.finished_sending - and not kv_connector_output.finished_recving): + if kv_connector_output.is_empty(): return EMPTY_MODEL_RUNNER_OUTPUT output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) @@ -120,6 +119,8 @@ def _get_kv_connector_output( output.finished_sending, output.finished_recving = ( kv_connector.get_finished(scheduler_output.finished_req_ids)) + output.invalid_block_ids = ( + kv_connector.get_block_ids_with_load_errors()) output.kv_connector_stats = KVConnectorModelRunnerMixin.\ get_kv_connector_stats()
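
For reference, the worker-side contract exercised by this change is small: a connector reports failed block IDs via `get_block_ids_with_load_errors()` no later than the step in which `get_finished()` reports the affected request, and the scheduler truncates each affected request to its longest valid prefix and reschedules it. A minimal sketch with hypothetical names (`DemoConnector`, `_failed_block_ids`, `_done_recving`), assuming the `KVConnectorBase_V1` methods shown in `base.py` above:

```python
from typing import Optional


class DemoConnector:  # would derive from KVConnectorBase_V1 in real code

    def __init__(self) -> None:
        self._failed_block_ids: set[int] = set()
        self._done_recving: set[str] = set()

    def record_load_failure(self, block_id: int) -> None:
        # Called from the transfer path when reading a block's KV fails.
        self._failed_block_ids.add(block_id)

    def get_block_ids_with_load_errors(self) -> set[int]:
        # Drain failures once per step; the scheduler truncates affected
        # requests to their longest valid prefix and reschedules them.
        failed, self._failed_block_ids = self._failed_block_ids, set()
        return failed

    def get_finished(
            self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        # A request with load failures must still be reported as finished
        # receiving; recovery is driven entirely by the invalid block IDs.
        done, self._done_recving = self._done_recving, set()
        return None, done or None
```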