diff --git a/tests/full_tests/ci_gsm8k_tests.sh b/tests/full_tests/ci_gsm8k_tests.sh
index 04e7f0a9..f9198748 100644
--- a/tests/full_tests/ci_gsm8k_tests.sh
+++ b/tests/full_tests/ci_gsm8k_tests.sh
@@ -165,16 +165,16 @@ fi
 echo "Test with deepseek R1 passed"
 
 # used to check HPUATTN + MOE + ExpertParallel
-# echo "Testing GSM8K on QWEN3-30B-A3B"
-# echo VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 TP_SIZE=2 \
-# pytest -v -s vllm-gaudi/tests/models/language/generation/test_common.py --model_card_path vllm-gaudi/tests/full_tests/model_cards/Qwen3-30B-A3B.yaml
-# VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 TP_SIZE=2 \
-# pytest -v -s vllm-gaudi/tests/models/language/generation/test_common.py --model_card_path vllm-gaudi/tests/full_tests/model_cards/Qwen3-30B-A3B.yaml
-# if [ $? -ne 0 ]; then
-# echo "Error: Test failed for QWEN3-30B-A3B" >&2
-# exit -1
-# fi
-# echo "Test with QWEN3-30B-A3B passed"
+echo "Testing GSM8K on QWEN3-30B-A3B"
+echo VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 TP_SIZE=2 \
+pytest -v -s vllm-gaudi/tests/models/language/generation/test_common.py --model_card_path vllm-gaudi/tests/full_tests/model_cards/Qwen3-30B-A3B.yaml
+VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 TP_SIZE=2 \
+pytest -v -s vllm-gaudi/tests/models/language/generation/test_common.py --model_card_path vllm-gaudi/tests/full_tests/model_cards/Qwen3-30B-A3B.yaml
+if [ $? -ne 0 ]; then
+  echo "Error: Test failed for QWEN3-30B-A3B" >&2
+  exit -1
+fi
+echo "Test with QWEN3-30B-A3B passed"
 
 # NOTE(Chendi): commented the test, it failed on upstream PR(#24444)
 # multimodal-support with qwen2.5-vl
diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py
index 072e0e83..ff406e3a 100644
--- a/vllm_gaudi/attention/backends/hpu_attn.py
+++ b/vllm_gaudi/attention/backends/hpu_attn.py
@@ -473,7 +473,8 @@ def forward(
                 k_scale=layer._k_scale_float,
                 v_scale=layer._k_scale_float,
             )
-
+        # Set return shape
+        output_shape = query.shape
         if query.dim() == 2:
             if attn_metadata.seq_lens_tensor is not None:
                 batch_size = attn_metadata.seq_lens_tensor.shape[0] if not self.use_merged_prefill else 1
@@ -583,8 +584,8 @@ def forward(
                 position_bias=self.position_bias,
                 **self.common_attention_args(block_list, key_cache, value_cache, attn_metadata.block_size))
-        # Reshape the output tensor.
-        return output.view(batch_size, seq_len, hidden_size)
+
+        return output.view(*output_shape)
 
     def common_attention_args(self, block_list=None, key_cache=None, value_cache=None, block_size=None):
         return {
diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py
index 0dce81a4..075ed07c 100644
--- a/vllm_gaudi/extension/features.py
+++ b/vllm_gaudi/extension/features.py
@@ -82,5 +82,6 @@ def get_features():
         Value('dynamic_shapes_compilation', True, env_var='VLLM_T_COMPILE_DYNAMIC_SHAPES', env_var_type=boolean),
         Value('fullgraph_compilation', False, env_var='VLLM_T_COMPILE_FULLGRAPH', env_var_type=boolean),
         Value('unified_attn', False),
+        Value('flatten_input', ModelType('qwen3_moe')),
     ]
     return split_values_and_flags(features)
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index d251b440..67f2aa09 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -326,6 +326,7 @@ def __init__(self, model, vllm_config):
         self._rotary_embed_module = self._get_rotary_embedding_module(self.model)
         self._rotary_prepare_cos_sin = self._get_prepare_cos_sin()
         self.unified_attn = get_config().unified_attn
+        self.flatten_input = get_config().flatten_input
 
     def _get_rotary_embedding_module(self, model: torch.nn.Module):
         """
@@ -447,6 +448,8 @@ def forward(self, *args, **kwargs):
             kwargs.update(model_mm_kwargs)
 
         num_input_tokens = input_ids.size(0) * input_ids.size(1)
+        if self.flatten_input:
+            kwargs['input_ids'] = input_ids.view(-1)
         with set_forward_context(attn_meta, self.vllm_config, num_tokens=num_input_tokens):
             hidden_states = self.model(*args, **kwargs)
         if self._rotary_prepare_cos_sin is not None:
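Taken together, the features.py and hpu_model_runner.py hunks enable input flattening for qwen3_moe (the runner collapses the [batch, seq] token ids with `input_ids.view(-1)`), and the hpu_attn.py hunks make the attention backend shape-agnostic: it records `output_shape = query.shape` on entry and returns `output.view(*output_shape)` instead of hard-coding `(batch_size, seq_len, hidden_size)`, which assumed a 3D query. Below is a minimal standalone sketch of that contract, not the vLLM code itself; the function name and its identity "attention" body are illustrative stand-ins:

import torch

def attention_forward_sketch(query: torch.Tensor) -> torch.Tensor:
    # Record the caller's layout up front, mirroring `output_shape = query.shape`
    # in the patched forward().
    output_shape = query.shape
    # Stand-in for the real attention computation, which works on a
    # flattened (num_tokens, hidden_size) layout internally.
    out = query.reshape(-1, output_shape[-1]).clone()
    # Hand the result back in the caller's original layout, whether the
    # query arrived 2D (flattened tokens) or 3D (batch, seq, hidden).
    return out.view(*output_shape)

# 3D path: the usual (batch, seq, hidden) query.
assert attention_forward_sketch(torch.randn(2, 4, 8)).shape == (2, 4, 8)
# 2D path: what the runner produces when flatten_input is active and
# input_ids was collapsed with `input_ids.view(-1)` before the model call.
assert attention_forward_sketch(torch.randn(8, 8)).shape == (8, 8)

Preserving the incoming shape lets both layouts pass through the same backend, which is presumably why the re-enabled Qwen3-30B-A3B test (HPUATTN + MOE + ExpertParallel) depends on these hunks.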