Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion atom/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def load_model_in_plugin_mode(
load_fused_expert_weights_fn=None,
spec_decode: bool = False,
hf_config_override: AutoConfig | None = None,
model_name_or_path_override: str | None = None,
) -> set[str]:

# during loading model, the outplace operation may consume more
Expand All @@ -215,7 +216,9 @@ def _empty_cache():
assert (
config.plugin_config is not None and config.plugin_config.is_plugin_mode
), "ATOM is not running in plugin mode"
if config.plugin_config.is_vllm:
if model_name_or_path_override is not None:
model_name_or_path = model_name_or_path_override
elif config.plugin_config.is_vllm:
model_name_or_path = config.plugin_config.model_config.model
elif config.plugin_config.is_sglang:
model_name_or_path = config.plugin_config.model_config.model_path
Expand Down
29 changes: 21 additions & 8 deletions atom/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2072,16 +2072,29 @@ def __init__(
self.fuse_input_norm_quant = False
self.fuse_ar_input_norm = ENABLE_ALLREDUCE_RMSNORM_FUSION
if quant_config is not None and ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION:
# While self.quant_dtype is resolved from the *layer* prefix, model
# checkpoints can keep the MLA a-proj in unquantized form via
# `exclude`, like bf16 in Kimi-K2.6-MXFP4. So only fuse when the
# attn a-proj is also quantized, or otherwise the fusion would
# result in GEMM on packed FP4 activation with bf16 weights, and
# lead to un-multipliable shapes.
attn_quant_dtype = self.self_attn.quant_dtype
enable_fp8_input_norm_quant = (
self.quant_dtype == dtypes.fp8 and use_triton_gemm()
self.quant_dtype == dtypes.fp8
and attn_quant_dtype == dtypes.fp8
and use_triton_gemm()
)
enable_fp4_input_norm_quant = self.quant_dtype == dtypes.fp4x2 and (
use_triton_gemm()
or _enable_non_triton_global_mxfp4_input_norm_quant(
config,
quant_config,
self.quant_dtype,
is_mtp_block,
enable_fp4_input_norm_quant = (
self.quant_dtype == dtypes.fp4x2
and attn_quant_dtype == dtypes.fp4x2
and (
use_triton_gemm()
or _enable_non_triton_global_mxfp4_input_norm_quant(
config,
quant_config,
self.quant_dtype,
is_mtp_block,
)
)
)
if enable_fp8_input_norm_quant or enable_fp4_input_norm_quant:
Expand Down
11 changes: 11 additions & 0 deletions atom/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,14 @@ def compute_logits(
) -> Optional[torch.Tensor]:
logits = self.lm_head(hidden_states)
return logits

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""Default Eagle3 aux hidden-state layer ids: early / middle / late of
the target model. Aligned with vLLM's default (see
vllm/model_executor/models/llama.py).
"""
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
10 changes: 7 additions & 3 deletions atom/plugin/vllm/attention/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,12 @@ def __init__(
self.parallel_config = config.parallel_config
self.cache_config = config.cache_config

self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
self.head_dim = self.model_config.get_head_size()
# For EAGLE3 mha draft with mla target, model_config describes the mla target,
# but this metadata builder servers the mha draft's own kv cache group. So derive
# the kv geometry from the kv_cache_spec, which in non-EAGLE case agrees with the
# model_config.
self.num_heads_kv = kv_cache_spec.num_kv_heads
self.head_dim = kv_cache_spec.head_size
self.block_size = kv_cache_spec.block_size

self.aot_sliding_window: tuple[int, int] | None = None
Expand Down Expand Up @@ -780,7 +784,7 @@ def build_for_cudagraph_capture(
class AiterMlaMetadataBuilderForVllm(MLACommonMetadataBuilder):
"""vLLM-only dense MLA metadata builder."""

_cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
_cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold = 1
query_len_support = QueryLenSupport.UNIFORM

Expand Down
Loading