diff --git a/examples/disaggregated_serving_xpyd/disaggregated_proxy_p2p_deepseek_v32_xpyd.py b/examples/disaggregated_serving_xpyd/disaggregated_proxy_p2p_deepseek_v32_xpyd.py new file mode 100644 index 00000000..8e426b94 --- /dev/null +++ b/examples/disaggregated_serving_xpyd/disaggregated_proxy_p2p_deepseek_v32_xpyd.py @@ -0,0 +1,585 @@ +import argparse +import copy +import importlib.util +import json +import os +import socket +import threading +import time +import uuid +from pathlib import Path +from typing import Any + +import aiohttp +import msgpack +import zmq +from quart import Quart, jsonify, make_response, request + + +DEFAULT_HTTP_HOST = "0.0.0.0" +DEFAULT_HTTP_PORT = 10001 +DEFAULT_DISCOVERY_HOST = "0.0.0.0" +DEFAULT_DISCOVERY_PORT = 30002 +DEFAULT_INSTANCE_TTL_SECONDS = 5 +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +app = Quart(__name__) + +count = 0 +prefill_instances: dict[str, tuple[str, float]] = {} +decode_instances: dict[str, tuple[str, float]] = {} +prefill_cv = threading.Condition() +decode_cv = threading.Condition() +instance_ttl_seconds = DEFAULT_INSTANCE_TTL_SECONDS + + +def _load_deepseek_v32_encoding_module(): + # Prefer installed package import (works regardless of script location). + try: + import vllm.tokenizers.deepseek_v32_encoding as _m + return _m + except ImportError: + pass + + # Fallback: locate the file relative to the project root. + # The script may live in any subdirectory (e.g. shell/, reproduction/), + # so walk upward until we find the vllm package directory. + script_dir = Path(__file__).resolve().parent + for base in [script_dir, script_dir.parent, script_dir.parent.parent]: + module_path = base / "vllm" / "tokenizers" / "deepseek_v32_encoding.py" + if module_path.exists(): + break + else: + raise ImportError( + "Unable to locate deepseek_v32_encoding.py. " + "Install the vllm package or place the script under the vllm project root." 
+ ) + + spec = importlib.util.spec_from_file_location( + "_router_nccl_deepseek_v32_encoding", + module_path, + ) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load DeepSeek V3.2 encoder from {module_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +_deepseek_v32_encoding = _load_deepseek_v32_encoding_module() +encode_messages = _deepseek_v32_encoding.encode_messages +parse_message_from_completion_text = ( + _deepseek_v32_encoding.parse_message_from_completion_text +) + + +def random_uuid() -> str: + return uuid.uuid4().hex + + +def _prune_expired(instances: dict[str, tuple[str, float]]) -> None: + now = time.time() + expired = [key for key, (_, deadline) in instances.items() if deadline <= now] + for key in expired: + zmq_addr, deadline = instances.pop(key) + print(f"remove instance http={key} zmq={zmq_addr} deadline={deadline}") + + +def _register_instance(role: str, http_address: str, zmq_address: str) -> None: + global prefill_instances + global decode_instances + + bucket = prefill_instances if role == "P" else decode_instances + cv = prefill_cv if role == "P" else decode_cv + deadline = time.time() + instance_ttl_seconds + + with cv: + node = bucket.get(http_address) + bucket[http_address] = (zmq_address, deadline) + _prune_expired(bucket) + if node is None: + print(f"add instance role={role} http={http_address} zmq={zmq_address}") + + +def _listen_for_register(poller: zmq.Poller, router_socket: Any) -> None: + while True: + socks = dict(poller.poll()) + if router_socket not in socks: + continue + + remote_address, message = router_socket.recv_multipart() + data = msgpack.loads(message) + role = data.get("type") + http_address = data.get("http_address") + zmq_address = data.get("zmq_address") + + if role not in {"P", "D"} or not http_address or not zmq_address: + print( + f"unexpected register remote={remote_address!r} data={data!r}", + ) + continue + + 
_register_instance(role, http_address, zmq_address) + + +def start_service_discovery(hostname: str, port: int) -> threading.Thread: + if not hostname: + hostname = socket.gethostname() + if port == 0: + raise ValueError("discovery port cannot be 0") + + context = zmq.Context() + router_socket = context.socket(zmq.ROUTER) + router_socket.bind(f"tcp://{hostname}:{port}") + + poller = zmq.Poller() + poller.register(router_socket, zmq.POLLIN) + + thread = threading.Thread( + target=_listen_for_register, + args=(poller, router_socket), + daemon=True, + ) + thread.start() + print(f"service discovery listening on tcp://{hostname}:{port}") + return thread + + +def _choose_instance( + instances: dict[str, tuple[str, float]], + cv: threading.Condition, + idx: int, +) -> tuple[str, str] | None: + with cv: + _prune_expired(instances) + if not instances: + return None + items = list(instances.items()) + http_addr, (zmq_addr, _) = items[idx % len(items)] + return http_addr, zmq_addr + + +async def _post_request( + url: str, + data: dict[str, Any], + request_id: str, + auth_header: str | None, + *, + stream: bool = False, +): + headers = {"X-Request-Id": request_id} + if auth_header: + headers["Authorization"] = auth_header + elif os.environ.get("OPENAI_API_KEY"): + headers["Authorization"] = f"Bearer {os.environ['OPENAI_API_KEY']}" + + if not stream: + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with session.post(url=url, json=data, headers=headers) as response: + body = await response.read() + return { + "ok": response.status == 200, + "status": response.status, + "body": body, + "content_type": response.headers.get( + "Content-Type", + "application/json", + ), + } + + session = aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) + try: + response = await session.post(url=url, json=data, headers=headers) + except Exception: + await session.close() + raise + + content_type = response.headers.get( + "Content-Type", + "application/json", + ) + if 
response.status != 200: + try: + body = await response.read() + return { + "ok": False, + "status": response.status, + "body": body, + "content_type": content_type, + } + finally: + response.close() + await session.close() + + async def _stream(): + try: + async for chunk in response.content.iter_chunked(8192): + yield chunk + finally: + response.close() + await session.close() + + return { + "ok": True, + "status": response.status, + "stream": _stream(), + "content_type": content_type, + } + + +def _build_request_id(prefill_zmq_addr: str, decode_zmq_addr: str) -> str: + return ( + f"___prefill_addr_{prefill_zmq_addr}___decode_addr_" + f"{decode_zmq_addr}_{random_uuid()}" + ) + + +def _is_chat_request_path(path: str) -> bool: + return path == "/v1/chat/completions" + + +def _error_response(message: str, status: int = 400): + return jsonify({"error": message}), status + + +def _normalize_message_content(message: dict[str, Any]) -> dict[str, Any]: + normalized = copy.deepcopy(message) + content = normalized.get("content") + if isinstance(content, list): + text_parts: list[str] = [] + for part in content: + if not isinstance(part, dict) or part.get("type") != "text": + raise ValueError( + "Only text content is supported in /v1/chat/completions bridge" + ) + text_parts.append(part.get("text", "")) + normalized["content"] = "".join(text_parts) + return normalized + + +def _get_thinking_mode(request_data: dict[str, Any]) -> str: + if request_data.get("thinking") or request_data.get("enable_thinking"): + return "thinking" + + reasoning_effort = request_data.get("reasoning_effort") + include_reasoning = request_data.get("include_reasoning", True) + if reasoning_effort not in (None, "none") and include_reasoning: + return "thinking" + + return "chat" + + +def _build_chat_prompt(request_data: dict[str, Any]) -> tuple[str, str]: + messages = request_data.get("messages") + if not isinstance(messages, list) or len(messages) == 0: + raise ValueError("`messages` must be a 
non-empty list") + + normalized_messages = [_normalize_message_content(msg) for msg in messages] + thinking_mode = _get_thinking_mode(request_data) + + system_metadata: dict[str, Any] = {} + if request_data.get("tools"): + system_metadata["tools"] = copy.deepcopy(request_data["tools"]) + if request_data.get("response_format"): + system_metadata["response_format"] = copy.deepcopy( + request_data["response_format"] + ) + if system_metadata: + normalized_messages.insert(0, {"role": "system", **system_metadata}) + + drop_thinking = normalized_messages[-1].get("role") in {"user", "developer"} + prompt = encode_messages( + normalized_messages, + thinking_mode=thinking_mode, + drop_thinking=drop_thinking, + ) + return prompt, thinking_mode + + +def _build_completion_request_from_chat( + request_data: dict[str, Any], +) -> tuple[dict[str, Any], str]: + if request_data.get("stream"): + raise ValueError( + "`stream=true` is not supported for bridged /v1/chat/completions" + ) + + prompt, thinking_mode = _build_chat_prompt(request_data) + + completion_request = copy.deepcopy(request_data) + for key in ( + "messages", + "tools", + "tool_choice", + "response_format", + "stream_options", + "reasoning_effort", + "include_reasoning", + "parallel_tool_calls", + "user", + "chat_template", + "chat_template_kwargs", + "add_generation_prompt", + "continue_final_message", + "add_special_tokens", + "documents", + "thinking", + "enable_thinking", + ): + completion_request.pop(key, None) + + max_completion_tokens = completion_request.pop("max_completion_tokens", None) + if max_completion_tokens is not None and completion_request.get("max_tokens") is None: + completion_request["max_tokens"] = max_completion_tokens + + completion_request["prompt"] = prompt + completion_request["stream"] = False + return completion_request, thinking_mode + + +def _materialize_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]: + materialized: list[dict[str, Any]] = [] + for tool_call in 
tool_calls: + function = tool_call.get("function", {}) + materialized.append( + { + "id": f"call_{random_uuid()}", + "type": tool_call.get("type", "function"), + "function": { + "name": function.get("name"), + "arguments": function.get("arguments", ""), + }, + } + ) + return materialized + + +def _convert_completion_to_chat_response( + completion_payload: dict[str, Any], + thinking_mode: str, + include_reasoning: bool, +) -> dict[str, Any]: + choices: list[dict[str, Any]] = [] + for choice in completion_payload.get("choices", []): + text = choice.get("text", "") + try: + parsed_message = parse_message_from_completion_text(text, thinking_mode) + except Exception: + parsed_message = { + "role": "assistant", + "content": text, + "reasoning": "", + "tool_calls": [], + } + + tool_calls = _materialize_tool_calls(parsed_message.get("tool_calls", [])) + content = parsed_message.get("content") + message: dict[str, Any] = { + "role": "assistant", + "content": content if content or not tool_calls else None, + } + if include_reasoning and parsed_message.get("reasoning"): + message["reasoning"] = parsed_message["reasoning"] + if tool_calls: + message["tool_calls"] = tool_calls + + finish_reason = choice.get("finish_reason") + if tool_calls and finish_reason in (None, "stop"): + finish_reason = "tool_calls" + + choices.append( + { + "index": choice.get("index", 0), + "message": message, + "logprobs": None, + "finish_reason": finish_reason or "stop", + "stop_reason": choice.get("stop_reason"), + "token_ids": choice.get("token_ids"), + } + ) + + response_id = completion_payload.get("id", f"chatcmpl-{random_uuid()}") + if isinstance(response_id, str) and response_id.startswith("cmpl-"): + response_id = "chat" + response_id + + return { + "id": response_id, + "object": "chat.completion", + "created": completion_payload.get("created", int(time.time())), + "model": completion_payload.get("model"), + "choices": choices, + "usage": completion_payload.get("usage"), + "service_tier": 
completion_payload.get("service_tier"), + "system_fingerprint": completion_payload.get("system_fingerprint"), + "prompt_logprobs": completion_payload.get("prompt_logprobs"), + "prompt_token_ids": completion_payload.get("prompt_token_ids"), + "kv_transfer_params": completion_payload.get("kv_transfer_params"), + } + + +@app.route("/health", methods=["GET"]) +async def health(): + with prefill_cv: + _prune_expired(prefill_instances) + prefill_count = len(prefill_instances) + with decode_cv: + _prune_expired(decode_instances) + decode_count = len(decode_instances) + return jsonify( + { + "status": "ok", + "prefill_instances": prefill_count, + "decode_instances": decode_count, + } + ) + + +@app.route("/debug/instances", methods=["GET"]) +async def debug_instances(): + with prefill_cv: + _prune_expired(prefill_instances) + prefills = dict(prefill_instances) + with decode_cv: + _prune_expired(decode_instances) + decodes = dict(decode_instances) + return jsonify({"prefill": prefills, "decode": decodes}) + + +@app.route("/v1/completions", methods=["POST"]) +@app.route("/v1/chat/completions", methods=["POST"]) +async def handle_request(): + global count + + original_request_data = await request.get_json() + auth_header = request.headers.get("Authorization") + is_chat_request = _is_chat_request_path(request.path) + include_reasoning = bool(original_request_data.get("include_reasoning", True)) + thinking_mode = "chat" + + if is_chat_request: + try: + request_data, thinking_mode = _build_completion_request_from_chat( + original_request_data + ) + except ValueError as exc: + return _error_response(str(exc), 400) + else: + request_data = original_request_data + upstream_path = "/v1/completions" if is_chat_request else request.path + + pair_index = count + count += 1 + + prefill = _choose_instance(prefill_instances, prefill_cv, pair_index) + decode = _choose_instance(decode_instances, decode_cv, pair_index) + + if prefill is None: + return ( + jsonify({"error": "no registered 
prefill instances"}),
+            503,
+        )
+    if decode is None:
+        return (
+            jsonify({"error": "no registered decode instances"}),
+            503,
+        )
+
+    prefill_addr, prefill_zmq_addr = prefill
+    decode_addr, decode_zmq_addr = decode
+
+    print(
+        "route request "
+        f"[HTTP:{prefill_addr}, ZMQ:{prefill_zmq_addr}] -> "
+        f"[HTTP:{decode_addr}, ZMQ:{decode_zmq_addr}]"
+    )
+
+    prefill_request = dict(request_data)
+    prefill_request["max_tokens"] = 1
+    prefill_request["stream"] = False
+    prefill_request.pop("stream_options", None)
+    if "max_completion_tokens" in prefill_request:
+        prefill_request["max_completion_tokens"] = 1
+    should_stream = bool(request_data.get("stream"))
+
+    request_id = _build_request_id(prefill_zmq_addr, decode_zmq_addr)
+
+    prefill_result = await _post_request(
+        f"http://{prefill_addr}{upstream_path}",
+        prefill_request,
+        request_id,
+        auth_header,
+        stream=False,
+    )
+    if not prefill_result["ok"]:
+        return await make_response(
+            prefill_result["body"],
+            prefill_result["status"],
+            {"Content-Type": prefill_result["content_type"]},
+        )
+
+    decode_result = await _post_request(
+        f"http://{decode_addr}{upstream_path}",
+        request_data,
+        request_id,
+        auth_header,
+        stream=should_stream,
+    )
+    if not decode_result["ok"]:
+        return await make_response(
+            decode_result["body"],
+            decode_result["status"],
+            {"Content-Type": decode_result["content_type"]},
+        )
+
+    if is_chat_request:
+        try:
+            completion_payload = json.loads(decode_result["body"])
+        except json.JSONDecodeError:
+            return _error_response("decode instance returned non-JSON completion body", 502)
+
+        chat_payload = _convert_completion_to_chat_response(
+            completion_payload,
+            thinking_mode,
+            include_reasoning,
+        )
+        return await make_response(
+            json.dumps(chat_payload, ensure_ascii=False),
+            decode_result["status"],
+            {"Content-Type": "application/json"},
+        )
+
+    if should_stream:
+        response = await make_response(
+            decode_result["stream"],
+            decode_result["status"],
+            {"Content-Type": decode_result["content_type"]},
+ ) + response.timeout = None + return response + + return await make_response( + decode_result["body"], + decode_result["status"], + {"Content-Type": decode_result["content_type"]}, + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--host", default=DEFAULT_HTTP_HOST) + parser.add_argument("--http-port", type=int, default=DEFAULT_HTTP_PORT) + parser.add_argument("--discovery-host", default=DEFAULT_DISCOVERY_HOST) + parser.add_argument("--discovery-port", type=int, default=DEFAULT_DISCOVERY_PORT) + parser.add_argument("--instance-ttl-seconds", type=int, default=5) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + instance_ttl_seconds = args.instance_ttl_seconds + discovery_thread = start_service_discovery( + args.discovery_host, + args.discovery_port, + ) + app.run(host=args.host, port=args.http_port) + discovery_thread.join() diff --git a/vllm_fl/__init__.py b/vllm_fl/__init__.py index 2c823db4..03335ae4 100644 --- a/vllm_fl/__init__.py +++ b/vllm_fl/__init__.py @@ -3,7 +3,7 @@ import os import logging -from vllm_fl.utils import get_op_config as _get_op_config +# from vllm_fl.utils import get_op_config as _get_op_config from . 
import version as version # PyTorch-style: vllm_fl.version.git_version @@ -116,3 +116,14 @@ def register_model(): ) except Exception as e: logger.error(f"Register GlmMoeDsa model error: {str(e)}") + + try: + from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory + KVConnectorFactory._registry.pop("P2pNcclConnector", None) + KVConnectorFactory.register_connector( + "P2pNcclConnector", + "vllm_fl.distributed.kv_transfer.p2p_flagcx_connector", + "P2pNcclConnector", + ) + except Exception as e: + logger.error(f"Register P2pFlagcxConnector error: {str(e)}") diff --git a/vllm_fl/distributed/kv_transfer/__init__.py b/vllm_fl/distributed/kv_transfer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_fl/distributed/kv_transfer/p2p_flagcx_connector.py b/vllm_fl/distributed/kv_transfer/p2p_flagcx_connector.py new file mode 100644 index 00000000..423cac27 --- /dev/null +++ b/vllm_fl/distributed/kv_transfer/p2p_flagcx_connector.py @@ -0,0 +1,599 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import regex as re +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( + P2pNcclEngine, +) +from vllm.distributed.parallel_state import get_world_group +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +# --------------------------------------------------------------------------- +# Detect FlagCX availability for KV cache transfer. +# If FLAGCX_PATH is set and libflagcx.so exists, use P2pFlagcxEngine; +# otherwise fall back to native NCCL with clear diagnostics. 
+# --------------------------------------------------------------------------- +import logging as _logging +_kv_logger = _logging.getLogger(__name__) + +_use_flagcx = False +_flagcx_path = os.getenv('FLAGCX_PATH') +if _flagcx_path: + _flagcx_so = os.path.join(_flagcx_path, "build/lib/libflagcx.so") + if not os.path.isdir(_flagcx_path): + _kv_logger.warning( + "\u26a0\ufe0f FLAGCX_PATH=%s is set but directory does not exist. " + "KV transfer will use native NCCL.", + _flagcx_path, + ) + elif not os.path.isfile(_flagcx_so): + _kv_logger.warning( + "\u26a0\ufe0f FLAGCX_PATH=%s is set but %s not found. " + "KV transfer will use native NCCL. " + "Did you build FlagCX? (cd $FLAGCX_PATH && make)", + _flagcx_path, + _flagcx_so, + ) + else: + try: + from vllm_fl.distributed.kv_transfer.p2p_flagcx_engine import ( + P2pFlagcxEngine, + ) + _use_flagcx = True + _kv_logger.warning( + "\u2705 FlagCX detected (FLAGCX_PATH=%s). " + "KV cache P2P transfer will use FlagCX backend.", + _flagcx_path, + ) + except Exception as _e: + _kv_logger.warning( + "\u26a0\ufe0f FLAGCX_PATH=%s is set and libflagcx.so exists, " + "but failed to import P2pFlagcxEngine: %s. " + "KV transfer will use native NCCL.", + _flagcx_path, + _e, + ) +else: + _kv_logger.info( + "FLAGCX_PATH not set. KV transfer will use native NCCL." 
+ ) + +if TYPE_CHECKING: + from vllm.forward_context import ForwardContext + from vllm.v1.attention.backend import AttentionMetadata + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class ReqMeta: + # Request Id + request_id: str + # Request block ids + block_ids: torch.Tensor + # Request num tokens + num_tokens: int + + @staticmethod + def make_meta( + request_id: str, token_ids: list[int], block_ids: list[int], block_size: int + ) -> "ReqMeta": + block_ids_tensor = torch.tensor(block_ids) + return ReqMeta( + request_id=request_id, + block_ids=block_ids_tensor, + num_tokens=len(token_ids), + ) + + +@dataclass +class P2pNcclConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] + + def __init__(self): + self.requests = [] + + def add_request( + self, + request_id: str, + token_ids: list[int], + block_ids: list[int], + block_size: int, + ) -> None: + self.requests.append( + ReqMeta.make_meta(request_id, token_ids, block_ids, block_size) + ) + + +class P2pNcclConnector(KVConnectorBase_V1): + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig | None" = None, + ): + super().__init__( + vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config, + ) + self._block_size = vllm_config.cache_config.block_size + self._requests_need_load: dict[str, Any] = {} + self.is_producer = self._kv_transfer_config.is_kv_producer + self.chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {} + + self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0 + self._local_rank = ( + get_world_group().local_rank if role == KVConnectorRole.WORKER else 0 + ) + + _EngineClass = P2pFlagcxEngine if _use_flagcx else P2pNcclEngine + if _use_flagcx: + logger.warning( + "\u2705 KV transfer engine: P2pFlagcxEngine (FlagCX backend)" + ) + 
else: + logger.warning( + "KV transfer engine: P2pNcclEngine (native NCCL backend)" + ) + + self.p2p_nccl_engine = ( + _EngineClass( + local_rank=self._local_rank, + config=self._kv_transfer_config, + hostname="", + port_offset=self._rank, + ) + if role == KVConnectorRole.WORKER + else None + ) + + # ============================== + # Worker-side methods + # ============================== + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: + """Start loading the KV cache from the connector buffer to vLLM's + paged KV buffer. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + """ + + # Only consumer/decode loads KV Cache + if self.is_producer: + return + + assert self.p2p_nccl_engine is not None + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + + def inject_kv_into_layer( + layer: torch.Tensor, + kv_cache: torch.Tensor, + block_ids: torch.Tensor, + request_id: str, + ) -> None: + """ + Inject KV cache data into a given attention layer tensor. + + This function updates `layer` in-place with values from `kv_cache`, + handling different backend layouts: + - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are + indexed along the first dimension. + - FlashAttention: KV tensors are indexed along the second + dimension. + + If the number of provided block IDs does not match the number of KV + blocks, only the overlapping portion is updated, and a warning is + logged. + + Args: + layer (torch.Tensor): The attention layer KV tensor to update. + kv_cache (torch.Tensor): The KV cache tensor to inject. + block_ids (torch.Tensor): Indices of the blocks to update. + request_id (str): Request identifier used for logging. + + Returns: + None. The function modifies `layer` in-place. 
+ """ + if layer.ndim == 3 or layer.shape[1] == 2: # MLA or FlashInfer + num_block = kv_cache.shape[0] + self.check_tensors_except_dim(layer, kv_cache, 0) + if len(block_ids) == num_block: + layer[block_ids, ...] = kv_cache + else: + layer[block_ids[:num_block], ...] = kv_cache + logger.warning( + "🚧kv_cache does not match, block_ids:%d, " + "num_block:%d, request_id:%s", + len(block_ids), + num_block, + request_id, + ) + + elif layer.shape[0] == 2: # FlashAttention + num_block = kv_cache.shape[1] + self.check_tensors_except_dim(layer, kv_cache, 1) + if len(block_ids) == num_block: + layer[:, block_ids, ...] = kv_cache + else: + layer[:, block_ids[:num_block], ...] = kv_cache + logger.warning( + "🚧kv_cache does not match, block_ids:%d, " + "num_block:%d, request_id:%s", + len(block_ids), + num_block, + request_id, + ) + + # Get the metadata + metadata: KVConnectorMetadata = self._get_connector_metadata() + assert isinstance(metadata, P2pNcclConnectorMetadata) + + if metadata is None: + return + + # Load the KV for each request each layer + for request in metadata.requests: + request_id = request.request_id + ip, port = self.parse_request_id(request_id, False) + remote_address = ip + ":" + str(port + self._rank) + for layer_name in forward_context.no_compile_layers: + layer = forward_context.no_compile_layers[layer_name] + + # Only process layers that have kv_cache + # attribute (attention layers) Skip non-attention + # layers like FusedMoE + kv_cache = getattr(layer, "kv_cache", None) + if kv_cache is None: + continue + + layer = kv_cache[forward_context.virtual_engine] + + # Skip non-standard KV caches (e.g. V3.2 Indexer + # uint8 cache) that should not be transferred. 
+ if layer.dtype == torch.uint8: + continue + kv_cache = self.p2p_nccl_engine.recv_tensor( + request.request_id + "#" + layer_name, remote_address + ) + + if kv_cache is None: + logger.warning("🚧kv_cache is None, %s", request.request_id) + continue + + inject_kv_into_layer( + layer, kv_cache, request.block_ids, request.request_id + ) + + def wait_for_layer_load(self, layer_name: str) -> None: + """Blocking until the KV for a specific layer is loaded into vLLM's + paged buffer. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + return + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + + # Only producer/prefill saves KV Cache + if not self.is_producer: + return + + assert self.p2p_nccl_engine is not None + + def extract_kv_from_layer( + layer: torch.Tensor, + block_ids: torch.Tensor, + ) -> torch.Tensor: + """ + Extract KV cache slices from a given attention layer tensor. + + This function handles multiple backend layouts: + - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are + indexed along the first dimension. + - FlashAttention: KV tensors are indexed along the second + dimension. + + Args: + layer (torch.Tensor): The KV cache from the attention layer. + block_ids (torch.Tensor): Indices of blocks to extract. + + Returns: + torch.Tensor: A tensor containing the extracted KV slices. + Returns None if the layout is unsupported. + """ + if layer.ndim == 3 or layer.shape[1] == 2: # MLA or FlashInfer + return layer[block_ids, ...] 
+ + if layer.shape[0] == 2: # FlashAttention + return layer[:, block_ids, ...] + + return None + + connector_metadata = self._get_connector_metadata() + assert isinstance(connector_metadata, P2pNcclConnectorMetadata) + for request in connector_metadata.requests: + request_id = request.request_id + ip, port = self.parse_request_id(request_id, True) + remote_address = ip + ":" + str(port + self._rank) + + kv_cache = extract_kv_from_layer(kv_layer, request.block_ids) + if kv_cache is None: + logger.warning( + "🚧Unsupported KV cache layout for layer %s, " + "request_id:%s, shape:%s", + layer_name, + request_id, + kv_layer.shape, + ) + continue + self.p2p_nccl_engine.send_tensor( + request_id + "#" + layer_name, kv_cache, remote_address + ) + + def wait_for_save(self): + if self.is_producer: + assert self.p2p_nccl_engine is not None + self.p2p_nccl_engine.wait_for_sent() + + def get_finished( + self, finished_req_ids: set[str], **kwargs: Any + ) -> tuple[set[str] | None, set[str] | None]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer, + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + + assert self.p2p_nccl_engine is not None + + no_compile_layers = self._vllm_config.compilation_config.static_forward_context + return self.p2p_nccl_engine.get_finished(finished_req_ids, no_compile_layers) + + # ============================== + # Scheduler-side methods + # ============================== + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. 
+ num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + if self.is_producer: + return 0, False + + prompt_token_ids = request.prompt_token_ids or [] + num_external_tokens = len(prompt_token_ids) - 1 - num_computed_tokens + + if num_external_tokens < 0: + num_external_tokens = 0 + + return num_external_tokens, False + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + """ + Update KVConnector state after block allocation. + """ + if not self.is_producer and num_external_tokens > 0: + self._requests_need_load[request.request_id] = ( + request, + blocks.get_block_ids()[0], + ) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + + meta = P2pNcclConnectorMetadata() + + for new_req in scheduler_output.scheduled_new_reqs: + if self.is_producer: + num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[ + new_req.req_id + ] + num_tokens = num_scheduled_tokens + new_req.num_computed_tokens + # the request's prompt is chunked prefill + if num_tokens < len(new_req.prompt_token_ids or []): + # 'CachedRequestData' has no attribute 'prompt_token_ids' + self.chunked_prefill[new_req.req_id] = ( + new_req.block_ids[0], + new_req.prompt_token_ids, + ) + continue + # the request's prompt is not chunked prefill + meta.add_request( + request_id=new_req.req_id, + token_ids=new_req.prompt_token_ids or [], + block_ids=new_req.block_ids[0], + block_size=self._block_size, + ) + continue + if new_req.req_id in self._requests_need_load: + meta.add_request( + request_id=new_req.req_id, + token_ids=new_req.prompt_token_ids or [], + block_ids=new_req.block_ids[0], + block_size=self._block_size, + ) + self._requests_need_load.pop(new_req.req_id) + + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + num_computed_tokens = cached_reqs.num_computed_tokens[i] + new_block_ids = cached_reqs.new_block_ids[i] + resumed_from_preemption = req_id in cached_reqs.resumed_req_ids + + if self.is_producer: + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_tokens = num_scheduled_tokens + num_computed_tokens + assert req_id in self.chunked_prefill + assert new_block_ids is not None + block_ids = new_block_ids[0] + if not resumed_from_preemption: + block_ids = self.chunked_prefill[req_id][0] + block_ids + prompt_token_ids = self.chunked_prefill[req_id][1] + assert prompt_token_ids is not None + # the request's prompt is chunked prefill again + if num_tokens < len(prompt_token_ids): + self.chunked_prefill[req_id] = (block_ids, prompt_token_ids) + continue + # the request's prompt is all prefilled finally + meta.add_request( + 
request_id=req_id, + token_ids=prompt_token_ids, + block_ids=block_ids, + block_size=self._block_size, + ) + self.chunked_prefill.pop(req_id, None) + continue + + # NOTE(rob): here we rely on the resumed requests being + # the first N requests in the list scheduled_cache_reqs. + if not resumed_from_preemption: + break + if req_id in self._requests_need_load: + request, _ = self._requests_need_load.pop(req_id) + total_tokens = num_computed_tokens + 1 + token_ids = request.all_token_ids[:total_tokens] + + # NOTE(rob): For resumed req, new_block_ids is all + # of the block_ids for the request. + assert new_block_ids is not None + block_ids = new_block_ids[0] + + meta.add_request( + request_id=req_id, + token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size, + ) + + self._requests_need_load.clear() + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. 
+ """ + + self.chunked_prefill.pop(request.request_id, None) + + return False, None + + # ============================== + # Static methods + # ============================== + + @staticmethod + def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]: + # Regular expression to match the string hostname and integer port + if is_prefill: + pattern = r"___decode_addr_(.*):(\d+)" + else: + pattern = r"___prefill_addr_(.*):(\d+)___" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + if match: + # Extract the ranks + ip = match.group(1) + port = int(match.group(2)) + + return ip, port + raise ValueError(f"Request id {request_id} does not contain hostname and port") + + @staticmethod + def check_tensors_except_dim(tensor1, tensor2, dim): + shape1 = tensor1.size() + shape2 = tensor2.size() + + if len(shape1) != len(shape2) or not all( + s1 == s2 for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim + ): + raise NotImplementedError( + "Currently, only symmetric TP is supported. Asymmetric TP, PP," + "and others will be supported in future PRs." + ) diff --git a/vllm_fl/distributed/kv_transfer/p2p_flagcx_engine.py b/vllm_fl/distributed/kv_transfer/p2p_flagcx_engine.py new file mode 100644 index 00000000..cc862fd9 --- /dev/null +++ b/vllm_fl/distributed/kv_transfer/p2p_flagcx_engine.py @@ -0,0 +1,391 @@ +# Copyright (c) 2026 BAAI. All rights reserved. 

import ctypes
import json
import logging
import os
import sys
import threading
import time
from collections import deque
from typing import Any

import msgpack
import torch
import zmq

from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
    P2pNcclEngine,
    SendQueueItem,
    set_p2p_nccl_context,
    DEFAULT_MEM_POOL_SIZE_GB,
)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import (
    TensorMemoryPool,
)
from vllm.utils.network_utils import get_ip
from vllm.utils.torch_utils import current_stream

# Make the FlagCX python wrapper importable from $FLAGCX_PATH.
_flagcx_path = os.getenv('FLAGCX_PATH')
if _flagcx_path and os.path.isdir(_flagcx_path):
    if _flagcx_path not in sys.path:
        sys.path.append(_flagcx_path)

from plugin.interservice.flagcx_wrapper import (
    FLAGCXLibrary,
    buffer_type,
    flagcxComm_t,
    flagcxDataTypeEnum,
    flagcxUniqueId,
)

logger = logging.getLogger(__name__)


class P2pFlagcxEngine(P2pNcclEngine):
    """P2P engine using FlagCX for KV cache transfer instead of native NCCL.

    Subclasses P2pNcclEngine and overrides only the communication-backend
    specific methods (__init__, create_connect, listen_for_requests, send,
    recv). All backend-agnostic logic (ZMQ signaling, send/recv queues,
    memory pool, threading) is inherited unchanged.
    """

    def __init__(
        self,
        local_rank: int,
        config: KVTransferConfig,
        hostname: str = "",
        port_offset: int = 0,
        library_path: str | None = None,
    ) -> None:
        # NOTE: super().__init__() is deliberately NOT called: the parent
        # loads NCCLLibrary and starts an NCCL-bound listener thread.
        # Its init logic is replicated here with FlagCX as the backend.

        self.config = config
        self.rank = port_offset
        self.local_rank = local_rank
        self.device = torch.device(f"cuda:{self.local_rank}")

        # --- FlagCX instead of NCCL --------------------------------------
        if library_path is None:
            flagcx_path = os.getenv('FLAGCX_PATH')
            if flagcx_path:
                library_path = os.path.join(
                    flagcx_path, "build/lib/libflagcx.so"
                )
        # NOTE(review): when FLAGCX_PATH is unset, library_path stays None
        # and FLAGCXLibrary must resolve a default — confirm wrapper handles
        # a None path.
        self.flagcx = FLAGCXLibrary(library_path)

        # --- Remainder mirrors P2pNcclEngine.__init__ --------------------
        if not hostname:
            hostname = get_ip()
        port = int(self.config.kv_port) + port_offset
        if port == 0:
            raise ValueError("Port cannot be 0")
        self._hostname = hostname
        self._port = port

        # Address other engines use to reach this engine's ZMQ router.
        self.zmq_address = f"{self._hostname}:{self._port}"

        # Optional proxy registration: both ip and port must be configured.
        proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
        proxy_port = self.config.get_from_extra_config("proxy_port", "")
        if proxy_ip == "" or proxy_port == "":
            self.proxy_address = ""
            self.http_address = ""
        else:
            self.proxy_address = proxy_ip + ":" + proxy_port
            http_port = self.config.get_from_extra_config("http_port", None)
            if http_port is None:
                example_cfg = {
                    "kv_connector": "P2pNcclConnector",
                    "kv_connector_extra_config": {"http_port": 8000},
                }
                example = (
                    f"--port=8000 --kv-transfer-config="
                    f"'{json.dumps(example_cfg)}'"
                )
                raise ValueError(
                    "kv_connector_extra_config.http_port is required. "
                    f"Example: {example}"
                )
            self.http_address = f"{self._hostname}:{http_port}"

        # ROUTER socket: receives registrations and transfer commands.
        self.context = zmq.Context()
        self.router_socket = self.context.socket(zmq.ROUTER)
        self.router_socket.bind(f"tcp://{self.zmq_address}")

        self.poller = zmq.Poller()
        self.poller.register(self.router_socket, zmq.POLLIN)

        self.send_store_cv = threading.Condition()
        self.send_queue_cv = threading.Condition()
        self.recv_store_cv = threading.Condition()

        # Dedicated CUDA streams so transfers overlap with compute.
        self.send_stream = torch.cuda.Stream()
        self.recv_stream = torch.cuda.Stream()

        # CPU-side spill pool for tensors exceeding the GPU buffer budget.
        mem_pool_size_gb = float(
            self.config.get_from_extra_config(
                "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB
            )
        )
        self.pool = TensorMemoryPool(
            max_block_size=int(mem_pool_size_gb * 1024**3)
        )

        # GET keeps tensors until the peer pulls them; PUT/PUT_ASYNC pushes.
        self.send_type = self.config.get_from_extra_config(
            "send_type", "PUT_ASYNC"
        )
        if self.send_type == "GET":
            self.send_store: dict[str, torch.Tensor] = {}
        else:
            self.send_queue: deque[SendQueueItem] = deque()
            if self.send_type == "PUT_ASYNC":
                self._send_thread = threading.Thread(
                    target=self.send_async, daemon=True
                )
                self._send_thread.start()

        self.recv_store: dict[str, Any] = {}
        self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
        self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
        self.socks: dict[str, Any] = {}   # remote zmq addr -> DEALER socket
        self.comms: dict[str, Any] = {}   # remote zmq addr -> (comm, rank)

        self.buffer_size = 0
        self.buffer_size_threshold = float(self.config.kv_buffer_size)

        self.nccl_num_channels = self.config.get_from_extra_config(
            "nccl_num_channels", "8"
        )

        self._listener_thread = threading.Thread(
            target=self.listen_for_requests, daemon=True
        )
        self._listener_thread.start()

        # Only rank-offset 0 pings the proxy (one registration per instance).
        self._ping_thread = None
        if port_offset == 0 and self.proxy_address != "":
            self._ping_thread = threading.Thread(
                target=self.ping, daemon=True
            )
            self._ping_thread.start()

        logger.warning(
            "💯P2pFlagcxEngine init, rank:%d, local_rank:%d, "
            "http_address:%s, zmq_address:%s, proxy_address:%s, "
            "send_type:%s, buffer_size_threshold:%.2f, "
            "nccl_num_channels:%s",
            self.rank,
            self.local_rank,
            self.http_address,
            self.zmq_address,
            self.proxy_address,
            self.send_type,
            self.buffer_size_threshold,
            self.nccl_num_channels,
        )

    # ------------------------------------------------------------------
    # Connection establishment (FlagCX unique-id / comm-init)
    # ------------------------------------------------------------------

    def create_connect(self, remote_address: str | None = None):
        """Return (socket, comm) to the remote peer, creating both if needed.

        The initiator takes rank 0; the listener side (listen_for_requests)
        takes rank 1 of the two-rank FlagCX communicator.
        """
        assert remote_address is not None
        if remote_address not in self.socks:
            sock = self.context.socket(zmq.DEALER)
            sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
            sock.connect(f"tcp://{remote_address}")
            self.socks[remote_address] = sock
            if remote_address in self.comms:
                logger.info(
                    "👋comm exists, remote_address:%s, comms:%s",
                    remote_address,
                    self.comms,
                )
                return sock, self.comms[remote_address]

            # FlagCX: obtain a unique id and ship its bytes over ZMQ so the
            # peer can join the communicator.
            unique_id_ptr = self.flagcx.flagcxGetUniqueId()
            unique_id = unique_id_ptr.contents
            handshake = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
            sock.send(msgpack.dumps(handshake))

            with torch.accelerator.device_index(self.device.index):
                rank = 0
                with set_p2p_nccl_context(self.nccl_num_channels):
                    comm: flagcxComm_t = self.flagcx.flagcxCommInitRank(
                        2, ctypes.byref(unique_id), rank
                    )
                self.comms[remote_address] = (comm, rank)
                logger.info(
                    "🤝flagcxCommInitRank Success, %s👉%s, MyRank:%s",
                    self.zmq_address,
                    remote_address,
                    rank,
                )

        return self.socks[remote_address], self.comms[remote_address]

    # ------------------------------------------------------------------
    # Listener thread (FlagCX unique-id deserialization)
    # ------------------------------------------------------------------

    def listen_for_requests(self):
        """Daemon loop: service NEW / PUT / GET commands from remote peers."""
        while True:
            socks = dict(self.poller.poll())
            if self.router_socket not in socks:
                continue

            remote_address, message = self.router_socket.recv_multipart()
            data = msgpack.loads(message)

            if data["cmd"] == "NEW":
                # FlagCX: reconstruct the unique id from its raw bytes and
                # join the initiator's communicator as rank 1.
                unique_id = self.flagcx.unique_id_from_bytes(
                    bytes(data["unique_id"])
                )
                with torch.accelerator.device_index(self.device.index):
                    rank = 1
                    with set_p2p_nccl_context(self.nccl_num_channels):
                        comm: flagcxComm_t = self.flagcx.flagcxCommInitRank(
                            2, ctypes.byref(unique_id), rank
                        )
                    self.comms[remote_address.decode()] = (comm, rank)
                    logger.info(
                        "🤝flagcxCommInitRank Success, %s👈%s, MyRank:%s",
                        self.zmq_address,
                        remote_address.decode(),
                        rank,
                    )

            elif data["cmd"] == "PUT":
                tensor_id = data["tensor_id"]
                try:
                    # Allocate the landing buffer, ACK (b"0"), then receive.
                    with torch.cuda.stream(self.recv_stream):
                        tensor = torch.empty(
                            data["shape"],
                            dtype=getattr(torch, data["dtype"]),
                            device=self.device,
                        )
                    self.router_socket.send_multipart(
                        [remote_address, b"0"]
                    )
                    comm, rank = self.comms[remote_address.decode()]
                    self.recv(comm, tensor, rank ^ 1, self.recv_stream)
                    tensor_size = tensor.element_size() * tensor.numel()
                    if (
                        self.buffer_size + tensor_size
                        > self.buffer_size_threshold
                    ):
                        # Over GPU budget: spill into the memory pool and
                        # keep only (addr, dtype, shape) for later reload.
                        addr = self.pool.store_tensor(tensor)
                        tensor = (addr, tensor.dtype, tensor.shape)
                        logger.warning(
                            "🔴[PUT]Recv Tensor, Out Of Threshold, "
                            "%s👈%s, data:%s, addr:%d",
                            self.zmq_address,
                            remote_address.decode(),
                            data,
                            addr,
                        )
                    else:
                        self.buffer_size += tensor_size

                except torch.cuda.OutOfMemoryError:
                    # NACK (b"1"); store None so waiters unblock with failure.
                    self.router_socket.send_multipart(
                        [remote_address, b"1"]
                    )
                    tensor = None
                    logger.warning(
                        "🔴[PUT]Recv Tensor, Out Of Memory, "
                        "%s👈%s, data:%s",
                        self.zmq_address,
                        remote_address.decode(),
                        data,
                    )

                with self.recv_store_cv:
                    self.recv_store[tensor_id] = tensor
                    self.have_received_tensor_id(tensor_id)
                    self.recv_store_cv.notify()

            elif data["cmd"] == "GET":
                tensor_id = data["tensor_id"]
                with self.send_store_cv:
                    tensor = self.send_store.pop(tensor_id, None)
                    if tensor is not None:
                        reply = {
                            "ret": 0,
                            "shape": tensor.shape,
                            "dtype": str(tensor.dtype).replace("torch.", ""),
                        }
                        # Re-insert so the entry moves to the store's tail
                        # (most-recently-used position).
                        self.send_store[tensor_id] = tensor
                        self.have_sent_tensor_id(tensor_id)
                    else:
                        reply = {"ret": 1}

                self.router_socket.send_multipart(
                    [remote_address, msgpack.dumps(reply)]
                )

                if reply["ret"] == 0:
                    comm, rank = self.comms[remote_address.decode()]
                    self.send(
                        comm,
                        tensor.to(self.device),
                        rank ^ 1,
                        self.send_stream,
                    )
            else:
                logger.warning(
                    "🚧Unexpected, Received message from %s, data:%s",
                    remote_address,
                    data,
                )

    # ------------------------------------------------------------------
    # Data transfer (FlagCX send / recv with stream adaptation)
    # ------------------------------------------------------------------

    def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
        """Blocking point-to-point send of `tensor` to peer rank `dst`."""
        assert tensor.device == self.device, (
            f"this flagcx communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = current_stream()

        # Wrap the torch stream in a FlagCX adaptor stream for the call.
        flagcx_stream = self.flagcx.adaptor_stream_copy(stream)
        self.flagcx.flagcxSend(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            flagcxDataTypeEnum.from_torch(tensor.dtype),
            dst,
            comm,
            flagcx_stream,
        )
        self.flagcx.adaptor_stream_free(flagcx_stream)
        # Block until the transfer enqueued on `stream` has completed.
        stream.synchronize()

    def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
        """Blocking point-to-point receive into `tensor` from peer rank `src`."""
        assert tensor.device == self.device, (
            f"this flagcx communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = current_stream()

        # Wrap the torch stream in a FlagCX adaptor stream for the call.
        flagcx_stream = self.flagcx.adaptor_stream_copy(stream)
        self.flagcx.flagcxRecv(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            flagcxDataTypeEnum.from_torch(tensor.dtype),
            src,
            comm,
            flagcx_stream,
        )
        self.flagcx.adaptor_stream_free(flagcx_stream)
        # Block until the transfer enqueued on `stream` has completed.
        stream.synchronize()