diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index cce71f5d16..174f5bc775 100644 --- a/atom/model_loader/loader.py +++ b/atom/model_loader/loader.py @@ -201,6 +201,7 @@ def load_model_in_plugin_mode( load_fused_expert_weights_fn=None, spec_decode: bool = False, hf_config_override: AutoConfig | None = None, + model_name_or_path_override: str | None = None, ) -> set[str]: # during loading model, the outplace operation may consume more @@ -215,7 +216,9 @@ def _empty_cache(): assert ( config.plugin_config is not None and config.plugin_config.is_plugin_mode ), "ATOM is not running in plugin mode" - if config.plugin_config.is_vllm: + if model_name_or_path_override is not None: + model_name_or_path = model_name_or_path_override + elif config.plugin_config.is_vllm: model_name_or_path = config.plugin_config.model_config.model elif config.plugin_config.is_sglang: model_name_or_path = config.plugin_config.model_config.model_path diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index d2998efa81..e0dd792022 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -2072,16 +2072,29 @@ def __init__( self.fuse_input_norm_quant = False self.fuse_ar_input_norm = ENABLE_ALLREDUCE_RMSNORM_FUSION if quant_config is not None and ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION: + # While self.quant_dtype is resolved from the *layer* prefix, model + # checkpoints can keep the MLA a-proj in unquantized form via + # `exclude`, like bf16 in Kimi-K2.6-MXFP4. So only fuse when the + # attn a-proj is also quantized, or otherwise the fusion would + # result in GEMM on packed FP4 activation with bf16 weights, and + # lead to un-multipliable shapes. 
+ attn_quant_dtype = self.self_attn.quant_dtype enable_fp8_input_norm_quant = ( - self.quant_dtype == dtypes.fp8 and use_triton_gemm() + self.quant_dtype == dtypes.fp8 + and attn_quant_dtype == dtypes.fp8 + and use_triton_gemm() ) - enable_fp4_input_norm_quant = self.quant_dtype == dtypes.fp4x2 and ( - use_triton_gemm() - or _enable_non_triton_global_mxfp4_input_norm_quant( - config, - quant_config, - self.quant_dtype, - is_mtp_block, + enable_fp4_input_norm_quant = ( + self.quant_dtype == dtypes.fp4x2 + and attn_quant_dtype == dtypes.fp4x2 + and ( + use_triton_gemm() + or _enable_non_triton_global_mxfp4_input_norm_quant( + config, + quant_config, + self.quant_dtype, + is_mtp_block, + ) ) ) if enable_fp8_input_norm_quant or enable_fp4_input_norm_quant: diff --git a/atom/models/llama.py b/atom/models/llama.py index 349ebc5661..4f6e695be9 100644 --- a/atom/models/llama.py +++ b/atom/models/llama.py @@ -514,3 +514,14 @@ def compute_logits( ) -> Optional[torch.Tensor]: logits = self.lm_head(hidden_states) return logits + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """Default Eagle3 aux hidden-state layer ids: early / middle / late of + the target model. Aligned with vLLM's default (see + vllm/model_executor/models/llama.py). 
+ """ + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) diff --git a/atom/plugin/vllm/attention/metadata.py b/atom/plugin/vllm/attention/metadata.py index d411514d48..d83a95092d 100644 --- a/atom/plugin/vllm/attention/metadata.py +++ b/atom/plugin/vllm/attention/metadata.py @@ -408,8 +408,12 @@ def __init__( self.parallel_config = config.parallel_config self.cache_config = config.cache_config - self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) - self.head_dim = self.model_config.get_head_size() + # For EAGLE3 mha draft with mla target, model_config describes the mla target, + # but this metadata builder serves the mha draft's own kv cache group. So derive + # the kv geometry from the kv_cache_spec, which in non-EAGLE case agrees with the + # model_config. + self.num_heads_kv = kv_cache_spec.num_kv_heads + self.head_dim = kv_cache_spec.head_size self.block_size = kv_cache_spec.block_size self.aot_sliding_window: tuple[int, int] | None = None @@ -780,7 +784,7 @@ def build_for_cudagraph_capture( class AiterMlaMetadataBuilderForVllm(MLACommonMetadataBuilder): """vLLM-only dense MLA metadata builder.""" - _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH reorder_batch_threshold = 1 query_len_support = QueryLenSupport.UNIFORM diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index 5413614d5f..4131a492df 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -44,6 +44,17 @@ "DeepSeekMTPModel", "Glm4MoeMTPModel", } +_EAGLE3_DRAFT_ARCH_TO_ATOM_ARCH: dict[str, str] = { + # vLLM/HF draft arch name: ATOM server-mode draft class + "Eagle3LlamaForCausalLM": "Eagle3LlamaModel", + "LlamaForCausalLMEagle3": "Eagle3LlamaModel", + "Eagle3DeepseekV2ForCausalLM": "Eagle3DeepseekMLAModel", + "Eagle3DeepseekV3ForCausalLM": "Eagle3DeepseekMLAModel", +} +_EAGLE3_ATOM_DRAFT_ARCHS: set[str] 
= { + "Eagle3LlamaModel", + "Eagle3DeepseekMLAModel", +} # DeepSeek-V4 is a native ATOM model whose forward reads ATOM's own forward # context (not vLLM's). It needs the V4 proxy-cache bridge wired in the plugin # wrapper (register at init, bind + enter context per forward); see `forward`. @@ -131,12 +142,26 @@ def _maybe_set_v4_expert_dtype(atom_config, vllm_config) -> None: "KimiK25ForConditionalGeneration": "atom.plugin.vllm.models.kimi_k25:KimiK25ForConditionalGeneration_", "MiniMaxM2ForCausalLM": "atom.models.minimax_m2:MiniMaxM2ForCausalLM", "DeepseekV4ForCausalLM": "atom.plugin.vllm.models.deepseek_v4:DeepseekV4ForCausalLM", + "Eagle3LlamaModel": "atom.models.eagle3_llama:Eagle3LlamaModel", + "Eagle3DeepseekMLAModel": "atom.models.eagle3_deepseek_mla:Eagle3DeepseekMLAModel", } +def _normalize_atom_model_arch(model_arch: str) -> str: + return _EAGLE3_DRAFT_ARCH_TO_ATOM_ARCH.get(model_arch, model_arch) + + +def _is_eagle3_draft_arch(model_arch: str | None) -> bool: + return ( + model_arch in _EAGLE3_DRAFT_ARCH_TO_ATOM_ARCH + or model_arch in _EAGLE3_ATOM_DRAFT_ARCHS + ) + + def _get_atom_model_cls(model_arch: str) -> type: - if model_arch is not None and model_arch in _ATOM_MODEL_CLASSES: - model_ref = _ATOM_MODEL_CLASSES[model_arch] + normalized_arch = _normalize_atom_model_arch(model_arch) + if normalized_arch is not None and normalized_arch in _ATOM_MODEL_CLASSES: + model_ref = _ATOM_MODEL_CLASSES[normalized_arch] else: raise ValueError(f"The {model_arch} is not supported by ATOM OOT backend") @@ -244,8 +269,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): _set_framework_backbone("vllm") - self.config = vllm_config.model_config.hf_config - self.text_config = self.config.get_text_config() self.cache_config = vllm_config.cache_config self.device_config = vllm_config.device_config self.model_config = vllm_config.model_config @@ -261,14 +284,42 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vllm_config = 
vllm_config self.is_mtp = False + self.is_eagle3 = False speculative_config = getattr(vllm_config, "speculative_config", None) if speculative_config is not None: spec_method = speculative_config.method self.is_mtp = spec_method == "mtp" + self.is_eagle3 = spec_method == "eagle3" main_model_arch = vllm_config.model_config.architectures[0] - model_arch = _select_model_arch(vllm_config) - self.is_mtp_draft_model = self.is_mtp and model_arch != main_model_arch + selected_model_arch = _select_model_arch(vllm_config) + # Normalize vLLM or HF draft architecture to ATOM server-mode draft class, + # pass through for non-draft models + model_arch = _normalize_atom_model_arch(selected_model_arch) + draft_model_config = getattr(speculative_config, "draft_model_config", None) + draft_hf_config = getattr(draft_model_config, "hf_config", None) + self.is_mtp_draft_model = self.is_mtp and selected_model_arch != main_model_arch + self.is_eagle3_draft_model = ( + self.is_eagle3 + and selected_model_arch != main_model_arch + and _is_eagle3_draft_arch(selected_model_arch) + ) + self.is_spec_draft_model = self.is_mtp_draft_model or self.is_eagle3_draft_model + + if self.is_eagle3_draft_model and draft_hf_config is None: + raise ValueError("EAGLE3 draft model config is missing hf_config") + + self.config = ( + draft_hf_config + if self.is_eagle3_draft_model + else vllm_config.model_config.hf_config + ) + self.text_config = ( + self.config.get_text_config() + if hasattr(self.config, "get_text_config") + else self.config + ) + if self.is_mtp_draft_model: # Generate separate config for main model and draft model to make sure # that draft model has its own compilation config rather than carried @@ -276,11 +327,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): main_atom_config = get_current_atom_config() self.atom_config = _generate_atom_config_from_vllm_config(vllm_config) self.atom_config.hf_config = main_atom_config.hf_config + elif self.is_eagle3_draft_model: + 
self.atom_config = _generate_atom_config_from_vllm_config(vllm_config) + self.atom_config.hf_config = draft_hf_config else: self.atom_config = generate_atom_config_for_plugin_mode(vllm_config) # root HF config so --hf-overrides survive without losing multimodal # sub-configs such as Kimi-K2.5's vision_config/text_config. self.atom_config.hf_config = self.config + self.vllm_model_arch = selected_model_arch self.model_arch = model_arch logger.info( "ATOM vLLM hf config overrides: use_index_cache=%s, index_topk_freq=%s, " @@ -325,13 +380,26 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.atom_config.quant_config.apply_default_exclude_layers(default_excludes) logger.info(f"Construct ATOM model {model_arch} for vLLM plugin mode") - if self.is_mtp_draft_model: + if self.is_spec_draft_model: # Draft model's layers read get_current_atom_config() to register their # static_forward_context, so swap out the global atom_config temporarily # with the draft model's atom_config so that the correct forward context # can be registered with use_custom_atom_config(self.atom_config): - self.model = model_cls(self.atom_config) + if self.is_eagle3_draft_model: + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config + ) + logger.info( + "Construct EAGLE3 draft with layer_offset=%s", + target_layer_num, + ) + self.model = model_cls( + self.atom_config, + layer_offset=target_layer_num, + ) + else: + self.model = model_cls(self.atom_config) else: self.model = model_cls(self.atom_config) @@ -348,7 +416,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if model_arch in _MTP_MASK_INPUT_ARCH: self._adapt_mtp_layers_for_vllm() - if self.is_mtp: + if self.is_eagle3_draft_model: + self._enable_eagle3_draft_interface() + elif self.is_eagle3 and self._eagle3_uses_aux_hidden_state(): + self._enable_eagle3_target_interface() + if self.is_mtp or self.is_eagle3: # Mirror nested attributes required by vLLM speculative 
decoding. self._expose_spec_decode_attrs() @@ -418,6 +490,9 @@ def _expose_spec_decode_attrs(self) -> None: # (2) Propagate: future writes on the outer model sync to the inner # model. We create a one-off subclass so the hook only affects # this particular draft-model instance, not the base class. + # Create the one-off subclass only once + if getattr(model, "_atom_vllm_shared_attr_sync_patched", False): + return shared = self._WEIGHT_SHARED_ATTRS base_setattr = model.__class__.__setattr__ @@ -426,10 +501,16 @@ def _syncing_setattr(self_model, name, value): if name in shared and hasattr(inner, name): base_setattr(inner, name, value) + base_setattr(model, "_atom_vllm_shared_attr_sync_patched", True) + # Make the one-off subclass report its actual module instead of the + # base wrapper's model.__class__ = type( model.__class__.__name__, (model.__class__,), - {"__setattr__": _syncing_setattr}, + { + "__module__": model.__class__.__module__, + "__setattr__": _syncing_setattr, + }, ) def _register_indexer_caches_with_vllm(self): @@ -545,6 +626,92 @@ def masked_forward( return masked_forward + def _eagle3_uses_aux_hidden_state(self) -> bool: + vllm_spec_config = getattr(self.vllm_config, "speculative_config", None) + if getattr(vllm_spec_config, "method", None) != "eagle3": + return False + draft_model_config = getattr(vllm_spec_config, "draft_model_config", None) + hf_config = getattr(draft_model_config, "hf_config", None) + eagle_config = getattr(hf_config, "eagle_config", None) + if isinstance(eagle_config, dict): + return eagle_config.get("use_aux_hidden_state", True) + return True + + def _enable_eagle3_target_interface(self) -> None: + """Expose vLLM's SupportsEagle3 target surface by bridging to the inner + ATOM model's server-mode aux_hidden_state interface. + ATOM target models follow the server-mode convention, exposing + `set_aux_hidden_state_layers` and `get_eagle3_aux_hidden_state_layers`. 
+ vLLM's SupportsEagle3 instead calls `set_aux_hidden_state_layers` and + `get_eagle3_default_aux_hidden_state_layers`. + """ + model = self.model + if not ( + callable(getattr(model, "set_aux_hidden_state_layers", None)) + and callable(getattr(model, "get_eagle3_aux_hidden_state_layers", None)) + ): + raise RuntimeError( + f"Model {self.model_arch} cannot serve as an EAGLE3 target: it " + "does not expose the ATOM server-mode aux-hidden-state interface " + "(set_aux_hidden_state_layers / get_eagle3_aux_hidden_state_layers)." + ) + self.supports_eagle3 = True + self.has_own_lm_head = False + self.has_own_embed_tokens = False + self.set_aux_hidden_state_layers = model.set_aux_hidden_state_layers + self.get_eagle3_default_aux_hidden_state_layers = ( + self._resolve_eagle3_aux_hidden_state_layers + ) + + def _resolve_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + # Following ATOM server mode, prefer the draft's configured IDs that + # are already resolved from the possibly nested eagle_config by ATOM's + # SpeculativeConfig.__post_init__, and fall back to the target model's + # architecture default + spec_config = getattr(self.atom_config, "speculative_config", None) + aux_ids = list(getattr(spec_config, "eagle3_aux_layer_ids", None) or []) + if aux_ids: + return tuple(aux_ids) + return tuple(self.model.get_eagle3_aux_hidden_state_layers()) + + def _enable_eagle3_draft_interface(self) -> None: + # Expose vLLM's EAGLE3 draft `combine_hidden_states` by forwarding it to + # the inner ATOM draft model + model = self.model + if not callable(getattr(model, "combine_hidden_states", None)): + raise RuntimeError( + f"Model {self.model_arch} cannot serve as an EAGLE3 draft: it " + "does not implement combine_hidden_states()." 
+ ) + self.has_own_lm_head = False + self.has_own_embed_tokens = False + self.combine_hidden_states = model.combine_hidden_states + self._maybe_index_draft_attn_layer() + + def _maybe_index_draft_attn_layer(self) -> None: + # vLLM's bind_kv_cache calls extract_layer_index which asserts that + # each kv cache layer name contains only one integer. ATOM's + # Eagle3LlamaModel names its decoder layer as "midlayer", so prefix + # it with "layers.0." so that vLLM's assertion can pass + static_forward_context = self.vllm_compilation_config.static_forward_context + + for _name, module in self.model.named_modules(): + old_name = getattr(module, "layer_name", None) + if old_name is None or any(p.isdigit() for p in old_name.split(".")): + continue + new_name = f"layers.0.{old_name}" + if new_name in static_forward_context: + raise ValueError( + f"Cannot re-key draft attention layer {old_name} to " + f"{new_name}; name already registered." + ) + static_forward_context[new_name] = static_forward_context.pop(old_name) + module.layer_name = new_name + logger.info( + f"Re-keyed EAGLE3 draft attention layer {old_name} to " + f"{new_name} for vLLM to extract a layer index" + ) + def forward( self, input_ids: torch.Tensor | None, @@ -572,7 +739,22 @@ def forward( ] buf[: positions.numel()].copy_(positions) - if self._is_deepseek_v4: + if self.is_eagle3_draft_model: + if inputs_embeds is not None: + raise NotImplementedError( + "ATOM EAGLE3 draft wrappers do not support multimodal " + "inputs_embeds in vLLM plugin mode yet." 
+ ) + if "hidden_states" not in model_kwargs: + raise ValueError("EAGLE3 draft forward requires hidden_states.") + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + hidden_states=model_kwargs["hidden_states"], + ) + if not isinstance(hidden_states, tuple): + hidden_states = (hidden_states, hidden_states) + elif self._is_deepseek_v4: # DeepSeek-V4 is a native ATOM model: it reads ATOM's own forward # context and takes a native (input_ids, positions) forward — vLLM's # generic call contract (intermediate_tensors/inputs_embeds) does not @@ -614,7 +796,6 @@ def forward( inputs_embeds=inputs_embeds, **model_kwargs, ) - if not self.pp_group.is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -633,6 +814,7 @@ def load_weights( "Glm4MoeMTPModel", } draft_hf_config = None + draft_model_path = None if is_mtp_draft_model: draft_model_config = getattr( getattr(self.atom_config, "speculative_config", None), @@ -643,6 +825,20 @@ def load_weights( draft_hf_config = getattr( draft_model_config, "hf_config", draft_model_config ) + if self.is_eagle3_draft_model: + # EAGLE3 drafts are standalone checkpoints, so we need both the draft + # hf_config and the draft checkpoint path + spec_config = getattr(self.vllm_config, "speculative_config", None) + draft_model_config = getattr(spec_config, "draft_model_config", None) + if draft_model_config is not None: + draft_hf_config = getattr( + draft_model_config, "hf_config", draft_model_config + ) + draft_model_path = getattr( + draft_model_config, "model", None + ) or getattr(spec_config, "model", None) + if not draft_model_path: + raise ValueError("EAGLE3 draft model path is missing.") loaded_weights_record = load_model_in_plugin_mode( model=self.model, @@ -650,7 +846,17 @@ def load_weights( prefix="model.", spec_decode=is_mtp_draft_model, hf_config_override=draft_hf_config, + model_name_or_path_override=draft_model_path, ) + if self.is_eagle3_draft_model: + self.has_own_embed_tokens = any( 
+ "embed_tokens" in name for name in loaded_weights_record + ) + self.has_own_lm_head = any( + "lm_head" in name for name in loaded_weights_record + ) + self.model.has_own_embed_tokens = self.has_own_embed_tokens + self.model.has_own_lm_head = self.has_own_lm_head return loaded_weights_record def compute_logits( diff --git a/atom/plugin/vllm/models/kimi_k25.py b/atom/plugin/vllm/models/kimi_k25.py index 6d649d89b1..86f5023f47 100644 --- a/atom/plugin/vllm/models/kimi_k25.py +++ b/atom/plugin/vllm/models/kimi_k25.py @@ -157,6 +157,13 @@ def make_empty_intermediate_tensors( } ) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() diff --git a/atom/plugin/vllm/register.py b/atom/plugin/vllm/register.py index 033b9dc706..e9dcbe02cd 100644 --- a/atom/plugin/vllm/register.py +++ b/atom/plugin/vllm/register.py @@ -35,6 +35,10 @@ "KimiK25ForConditionalGeneration": "atom.plugin.vllm.models.kimi_k25:KimiK25ForConditionalGeneration", "MiniMaxM2ForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, "DeepseekV4ForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, + "Eagle3LlamaForCausalLM": ATOM_CAUSAL_LM_MODEL_WRAPPER, + "LlamaForCausalLMEagle3": ATOM_CAUSAL_LM_MODEL_WRAPPER, + "Eagle3DeepseekV2ForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, + "Eagle3DeepseekV3ForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, } diff --git a/atom/plugin/vllm/spec_decode_patch.py b/atom/plugin/vllm/spec_decode_patch.py index 7ec595b648..abc95ca1bf 100644 --- a/atom/plugin/vllm/spec_decode_patch.py +++ b/atom/plugin/vllm/spec_decode_patch.py @@ -4,6 +4,310 @@ logger = logging.getLogger("atom") +def _patch_eagle3_model_type_checks() -> None: + # vLLM's V1 EAGLE proposer 
SpecDecodeBaseProposer.propose() has an explicit + # isinstance() check for native vLLM EAGLE3 model classes before calling + # `combine_hidden_states()`. ATOM's vLLM plugin mode provides the same behavior + # through the ATOMModelBase wrapper, so patch the type checks to accept the + # ATOMModelBase wrapper + try: + from atom.plugin.vllm.model_wrapper import ATOMModelBase + import vllm.v1.spec_decode.llm_base_proposer as llm_base_proposer + except Exception: + logger.warning( + "vLLM plugin: failed to patch vLLM V1 EAGLE3 proposer type checks. " + "This can happen if you are using an in-compatible vLLM version. " + "Please make sure that the correct vLLM version is installed." + ) + return + + if getattr(llm_base_proposer, "_atom_eagle3_model_types_patched", False): + return + + # Supported archs in vLLM's `llm_base_proposer.py` + for name in ("Eagle3LlamaForCausalLM", "Eagle3DeepseekV2ForCausalLM"): + original = getattr(llm_base_proposer, name, None) + if original is None: + continue + if isinstance(original, tuple): + widened = (*original, ATOMModelBase) + else: + widened = (original, ATOMModelBase) + setattr(llm_base_proposer, name, widened) + + setattr(llm_base_proposer, "_atom_eagle3_model_types_patched", True) + logger.info("ATOM plugin: patched vLLM EAGLE3 proposer type checks.") + + +def _get_attn_backend_block_size(backend) -> int: + supported = backend.get_supported_kernel_block_sizes() + get_preferred = getattr(backend, "get_preferred_block_size", None) + if get_preferred is None: + return supported[0] + return get_preferred(supported[0]) + + +@functools.cache +def _get_mla_block_size() -> int: + from atom.plugin.vllm.attention.backend import AiterMlaBackendForVllm + + return _get_attn_backend_block_size(AiterMlaBackendForVllm) + + +@functools.cache +def _get_mha_block_size() -> int: + from atom.plugin.vllm.attention.backend import AiterMhaBackendForVllm + + return _get_attn_backend_block_size(AiterMhaBackendForVllm) + + +def 
_spec_has_heterogeneous_mla_mha_backend(kv_cache_spec) -> bool: + try: + from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec + except Exception: + return False + + has_mla = False + has_non_mla_attn = False + for spec in kv_cache_spec.values(): + if isinstance(spec, MLAAttentionSpec): + has_mla = True + elif isinstance(spec, AttentionSpec): + has_non_mla_attn = True + return has_mla and has_non_mla_attn + + +def _split_mla_and_mha_layers(kv_cache_spec): + from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec + + mla_layers = {} + mha_layers = {} + for name, spec in kv_cache_spec.items(): + if isinstance(spec, MLAAttentionSpec): + mla_layers[name] = spec + elif isinstance(spec, AttentionSpec): + mha_layers[name] = spec + else: + raise NotImplementedError( + "The heterogeneous EAGLE3 KV pool only supports MLA target with " + f"MHA draft, but got unexpected spec {type(spec).__name__} for " + f"layer {name}." + ) + return mla_layers, mha_layers + + +def _build_heterogeneous_kv_cache_groups(kv_cache_spec): + # Build separate groups for MLA and MHA with distinct block sizes and page sizes + # to bypass page size unification. + from vllm.v1.kv_cache_interface import KVCacheGroupSpec, UniformTypeKVCacheSpecs + + mla_layers, mha_layers = _split_mla_and_mha_layers(kv_cache_spec) + assert mla_layers, "Heterogeneous EAGLE3 requires at least 1 MLA layer" + assert mha_layers, "Heterogeneous EAGLE3 requires at least 1 MHA layer" + + # Use UniformTypeKVCacheSpecs so per-layer page sizes are preserved even + # if individual MLA layers differ, though they should be identical. 
+ mla_specs = { + name: spec.copy_with_new_block_size(_get_mla_block_size()) + for name, spec in mla_layers.items() + } + mla_uniform = UniformTypeKVCacheSpecs.from_specs(mla_specs) + assert mla_uniform is not None, ( + "Failed to build UniformTypeKVCacheSpecs for MLA target layers" + ) + mla_group = KVCacheGroupSpec( + layer_names=list(mla_specs.keys()), + kv_cache_spec=mla_uniform, + ) + + mha_specs = [ + spec.copy_with_new_block_size(_get_mha_block_size()) + for spec in mha_layers.values() + ] + merged_mha = mha_specs[0].merge(mha_specs) + mha_group = KVCacheGroupSpec( + layer_names=list(mha_layers.keys()), + kv_cache_spec=merged_mha, + ) + + return [mla_group, mha_group] + + +def _groups_are_heterogeneous_mla_mha(kv_cache_groups) -> bool: + try: + from vllm.v1.kv_cache_interface import ( + AttentionSpec, + MLAAttentionSpec, + UniformTypeKVCacheSpecs, + ) + except Exception: + logger.warning( + "vLLM plugin: failed to recognize ATOM heterogeneous EAGLE3 KV pool. " + "This can happen if you are using an in-compatible vLLM version. " + "Please make sure that the correct vLLM version is installed." 
+ ) + return False + + if len(kv_cache_groups) != 2: + return False + + def _is_mla(group): + spec = group.kv_cache_spec + if isinstance(spec, UniformTypeKVCacheSpecs): + specs = list(spec.kv_cache_specs.values()) + return bool(specs) and all(isinstance(s, MLAAttentionSpec) for s in specs) + return isinstance(spec, MLAAttentionSpec) + + def _is_mha(group): + spec = group.kv_cache_spec + if isinstance(spec, UniformTypeKVCacheSpecs): + specs = list(spec.kv_cache_specs.values()) + return bool(specs) and all( + isinstance(s, AttentionSpec) and not isinstance(s, MLAAttentionSpec) + for s in specs + ) + return isinstance(spec, AttentionSpec) and not isinstance( + spec, MLAAttentionSpec + ) + + g0, g1 = kv_cache_groups + return (_is_mla(g0) and _is_mha(g1)) or ( + _is_mla(g1) and _is_mha(g0) + ) + + +def _build_heterogeneous_kv_cache_config_from_groups( + vllm_config, kv_cache_groups, available_memory +): + # Custom kv cache allocator for mixed mla/mha target/draft layout. + # Allocates a single number of blocks for all layers of both groups + from vllm.v1.core.kv_cache_utils import may_override_num_blocks + from vllm.v1.kv_cache_interface import ( + KVCacheConfig, + KVCacheTensor, + UniformTypeKVCacheSpecs, + ) + + def _iter_layer_specs(group): + spec = group.kv_cache_spec + if isinstance(spec, UniformTypeKVCacheSpecs): + for layer_name, layer_spec in spec.kv_cache_specs.items(): + yield layer_name, layer_spec + else: + for layer_name in group.layer_names: + yield layer_name, spec + + bytes_per_block_all_layers = 0 + for group in kv_cache_groups: + for _layer_name, layer_spec in _iter_layer_specs(group): + bytes_per_block_all_layers += layer_spec.page_size_bytes + + assert bytes_per_block_all_layers > 0, "Zero per-block bytes" + num_blocks = available_memory // bytes_per_block_all_layers + num_blocks = max(num_blocks, 0) + num_blocks = may_override_num_blocks(vllm_config, num_blocks) + + kv_cache_tensors = [] + for group in kv_cache_groups: + for layer_name, layer_spec 
in _iter_layer_specs(group): + kv_cache_tensors.append( + KVCacheTensor( + size=layer_spec.page_size_bytes * num_blocks, + shared_by=[layer_name], + ) + ) + + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=kv_cache_tensors, + kv_cache_groups=kv_cache_groups, + ) + + +def _heterogeneous_max_memory_usage_bytes(vllm_config, kv_cache_groups): + # Max bytes needed for both groups to hold max_model_len tokens + from vllm.utils.math_utils import cdiv + from vllm.v1.kv_cache_interface import UniformTypeKVCacheSpecs + + max_model_len = vllm_config.model_config.max_model_len + total = 0 + for group in kv_cache_groups: + spec = group.kv_cache_spec + if isinstance(spec, UniformTypeKVCacheSpecs): + block_size = spec.block_size + per_block_bytes = sum( + s.page_size_bytes for s in spec.kv_cache_specs.values() + ) + else: + block_size = spec.block_size + per_block_bytes = spec.page_size_bytes * len(group.layer_names) + num_blocks_for_len = cdiv(max_model_len, block_size) + total += num_blocks_for_len * per_block_bytes + return total + + +def _patch_heterogeneous_eagle3_kv_cache() -> None: + """Patch vLLM KV-cache grouping/allocation for heterogeneous KV cache so + MLA target can coexist with an MHA EAGLE3 draft. + Only MLA target and MHA draft combination is supported for now. + """ + try: + import vllm.v1.core.kv_cache_utils as vllm_kv_cache_utils + except Exception: + logger.warning( + "ATOM plugin: failed to import vLLM kv_cache_utils; cannot enable " + "MLA target with MHA EAGLE3 draft. This can happen with " + "incompatible vLLM version." 
+ ) + return + + if getattr(vllm_kv_cache_utils, "_atom_heterogeneous_eagle3_patched", False): + return + + orig_get_groups = vllm_kv_cache_utils.get_kv_cache_groups + orig_config_from_groups = vllm_kv_cache_utils.get_kv_cache_config_from_groups + orig_max_mem = vllm_kv_cache_utils._max_memory_usage_bytes_from_groups + + @functools.wraps(orig_get_groups) + def patched_get_kv_cache_groups(vllm_config, kv_cache_spec): + if getattr( + vllm_config.model_config, "use_mla", False + ) and _spec_has_heterogeneous_mla_mha_backend(kv_cache_spec): + logger.info( + "ATOM plugin: using heterogeneous KV cache layout - MLA target " + "and MHA EAGLE3 draft - with separate per-group pools." + ) + return _build_heterogeneous_kv_cache_groups(kv_cache_spec) + return orig_get_groups(vllm_config, kv_cache_spec) + + @functools.wraps(orig_config_from_groups) + def patched_get_kv_cache_config_from_groups( + vllm_config, kv_cache_groups, available_memory + ): + if _groups_are_heterogeneous_mla_mha(kv_cache_groups): + return _build_heterogeneous_kv_cache_config_from_groups( + vllm_config, kv_cache_groups, available_memory + ) + return orig_config_from_groups(vllm_config, kv_cache_groups, available_memory) + + @functools.wraps(orig_max_mem) + def patched_max_memory_usage_bytes_from_groups(vllm_config, kv_cache_groups): + if _groups_are_heterogeneous_mla_mha(kv_cache_groups): + return _heterogeneous_max_memory_usage_bytes(vllm_config, kv_cache_groups) + return orig_max_mem(vllm_config, kv_cache_groups) + + vllm_kv_cache_utils.get_kv_cache_groups = patched_get_kv_cache_groups + vllm_kv_cache_utils.get_kv_cache_config_from_groups = patched_get_kv_cache_config_from_groups + vllm_kv_cache_utils._max_memory_usage_bytes_from_groups = ( + patched_max_memory_usage_bytes_from_groups + ) + vllm_kv_cache_utils._atom_heterogeneous_eagle3_patched = True + logger.info( + "ATOM plugin: patched vLLM KV-cache grouping/allocation for " + "MLA target with MHA EAGLE3 speculative decoding." 
+ ) + + def apply_vllm_spec_decode_patch() -> None: """Patch vLLM speculative decoding for ATOM metadata compatibility.""" from atom.plugin.vllm.attention.metadata import ( @@ -17,6 +321,9 @@ def apply_vllm_spec_decode_patch() -> None: ) from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer + _patch_eagle3_model_type_checks() + _patch_heterogeneous_eagle3_kv_cache() + original_init = SpecDecodeBaseProposer.__init__ if getattr(original_init, "_atom_allowed_attn_types_patched", False): return