72 changes: 72 additions & 0 deletions tests/unit_tests/sinks/test_gpt_oss.py
@@ -0,0 +1,72 @@
import vllm
import os
from vllm.entrypoints.llm import LLM

RUN_20B_MODEL = True  # Set to False to run the 120B model instead
MODEL_PATH = "lmsys/gpt-oss-20b-BF16"
MODEL_PATH_120 = "lmsys/gpt-oss-120b-BF16"
# reference https://github.com/huggingface/transformers/blob/68eb1a9a6353911f491b1c8139eb73d052a8e9b9/tests/models/gpt_oss/test_modeling_gpt_oss.py#L397
original_output = "Roses are red, violets are blue, I love you, and I love you too.\n\nRoses are red, vio"
# reference https://github.com/huggingface/transformers/blob/68eb1a9a6353911f491b1c8139eb73d052a8e9b9/tests/models/gpt_oss/test_modeling_gpt_oss.py#L462
original_output_120 = "Roses are red, violets are blue,\nI am a language model, not a human being"


def do_sample(llm: LLM, original_output: str, rtol: float, atol: float, max_num_seqs: int) -> list[str]:
    prompts = [
        "Roses are red, violets",
    ] * max_num_seqs

    sampling_params = vllm.SamplingParams(
        temperature=0,
        max_tokens=20,
    )
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    assert prompts[0] + generated_texts[0] == original_output, "Generated text does not match the expected output."
    return generated_texts


expected_output = [
    "are blue, I love you, and I love you too.\n\nRoses are red, vio"  # noqa: E501
]


def _test_gpt_oss():
    """Main function that sets up and runs the prompt processing."""
    if RUN_20B_MODEL:
        llm = LLM(
            MODEL_PATH,
            max_num_seqs=8,
            dtype='bfloat16',
            enforce_eager=True,
            max_model_len=512,
            max_num_batched_tokens=2048,
            tensor_parallel_size=1,
        )
        generated_texts = do_sample(llm, original_output=original_output, rtol=1e-01, atol=1e-01, max_num_seqs=1)
    else:
        llm = LLM(
            MODEL_PATH_120,
            max_num_seqs=8,
            dtype='bfloat16',
            enforce_eager=False,
            max_model_len=512,
            max_num_batched_tokens=2048,
            tensor_parallel_size=4,
        )
        generated_texts = do_sample(llm, original_output=original_output_120, rtol=1e-01, atol=1e-01, max_num_seqs=1)
    assert generated_texts == expected_output

Copilot AI Nov 10, 2025


Assertion compares single generated text with expected output incorrectly. The function returns a list but only validates the first element earlier. This assertion will fail unless generated_texts contains exactly one element matching expected_output[0]. Consider assert generated_texts[0] == expected_output[0] or assert generated_texts == expected_output after validating the list length.

Suggested change
assert generated_texts == expected_output
assert len(generated_texts) == len(expected_output)
assert generated_texts[0] == expected_output[0]
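As written, do_sample is only ever called with max_num_seqs=1, so the list comparison happens to hold. A minimal sketch of how the final check could be generalized if max_num_seqs were raised (illustrative only, not part of this PR):

    # Hypothetical generalization: with temperature=0 every sequence should
    # produce the same completion, so check each element against the single
    # expected string instead of comparing whole lists.
    assert len(generated_texts) >= 1
    for text in generated_texts:
        assert text == expected_output[0]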



def test_gpt_oss_1x():
    os.environ['VLLM_PROMPT_USE_FUSEDSDPA'] = '0'
    _test_gpt_oss()
    os.environ['VLLM_PROMPT_USE_FUSEDSDPA'] = '1'
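A note on the toggle above: if _test_gpt_oss() raises, the second assignment never runs and VLLM_PROMPT_USE_FUSEDSDPA stays at '0' for any later tests in the same process. A minimal sketch of a safer variant, assuming the test runs under pytest, whose built-in monkeypatch fixture restores the variable at teardown (illustrative only, not part of this PR):

    import pytest


    def test_gpt_oss_1x(monkeypatch: pytest.MonkeyPatch):
        # setenv is undone automatically after the test, even on failure.
        monkeypatch.setenv('VLLM_PROMPT_USE_FUSEDSDPA', '0')
        _test_gpt_oss()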
35 changes: 29 additions & 6 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -167,6 +167,7 @@ def __init__(
         qk_head_dim: int,
         v_head_dim: int,
         kv_b_proj: ColumnParallelLinear,
+        sinks: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> None:
         torch.nn.Module.__init__(self)
@@ -218,6 +219,11 @@ def __init__(
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "TritonMLAImpl")
+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
+                                                 f"heads in the layer. Sinks shape: {sinks.shape}, "
+                                                 f"num_heads: {num_heads}.")

     def forward(
         self,
@@ -389,6 +395,7 @@ def __init__(
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         super(AttentionImpl, self).__init__()
         if kv_sharing_target_layer_name is not None:
@@ -453,6 +460,11 @@ def __init__(
             raise NotImplementedError("Encoder self-attention "
                                       "is not implemented for "
                                       "HPUAttentionImpl")
+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
+                                                 f"heads in the layer. Sinks shape: {sinks.shape}, "
+                                                 f"num_heads: {num_heads}.")

     def _maybe_init_alibi_biases(
         self,
@@ -534,6 +546,12 @@ def forward(
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
+            if key.dtype != key_cache.dtype:
+                key = key.to(key_cache.dtype)
+            if value.dtype != value_cache.dtype:
+                value = value.to(value_cache.dtype)
+            if query.dtype != key.dtype:
+                query = query.to(key.dtype)
             key_cache = self.k_cache(key, key_cache, slot_mapping)
             value_cache = self.v_cache(value, value_cache, slot_mapping)

@@ -570,13 +588,17 @@ def forward(

             common_args = self.common_attention_args(block_list, key_cache, value_cache, attn_metadata.block_size)

-            if self.sliding_window and hasattr(attn_metadata,
-                                               'window_attn_bias') and attn_metadata.window_attn_bias is not None \
-                                               and self.prefill_impl == 'naive_impl':
-                attn_bias = attn_metadata.window_attn_bias
+            if self.sliding_window:
+                if hasattr(attn_metadata, 'window_attn_bias') and attn_metadata.window_attn_bias is not None:
+                    attn_bias = attn_metadata.window_attn_bias
+                else:
+                    attn_bias = None
-                window_size = (self.sliding_window, 0)
-                common_args['window_size'] = window_size
+                # TODO - change 128 to proper window size

Copilot AI Nov 10, 2025


Inconsistent TODO format: should be 'TODO:' with a colon instead of a dash for consistency with project conventions.

Suggested change:
-                # TODO - change 128 to proper window size
+                # TODO: change 128 to proper window size

+                window_size = (
+                    128,

Copilot AI Nov 10, 2025


Magic number 128 used for window size. Consider defining this as a named constant or deriving it from self.sliding_window as indicated by the TODO comment.

Suggested change:
-                # TODO - change 128 to proper window size
-                window_size = (
-                    128,
+                # Use self.sliding_window for window size instead of hardcoded 128
+                window_size = (
+                    self.sliding_window,
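A minimal sketch of the named-constant alternative the comment mentions (the constant name is hypothetical, nothing in the PR defines it):

    # Hypothetical module-level fallback until the proper window size is plumbed through.
    DEFAULT_PROMPT_WINDOW_SIZE = 128

    # Inside the sliding-window branch, prefer the layer's configured window when set.
    window_size = (self.sliding_window or DEFAULT_PROMPT_WINDOW_SIZE, 0)
    common_args["window_size"] = window_size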

+                    0,
+                )
+                common_args["window_size"] = window_size

             out = ops.prompt_attention(impl=self.prefill_impl,
                                        query=query.view(query_shape),
@@ -641,6 +663,7 @@ def common_attention_args(self, block_list=None, key_cache=None, value_cache=Non
             'key_cache': key_cache,
             'value_cache': value_cache,
             'block_size': block_size,
+            "sinks": self.sinks,
         }

     def forward_encoder_decoder(