134 changes: 134 additions & 0 deletions test_gpt_oss_offline.py
@@ -0,0 +1,134 @@
import os
import sys
import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
import numpy as np

RUN_20B_MODEL = True # Set to False to run the 120B model instead
MODEL_PATH = "lmsys/gpt-oss-20b-BF16"
MODEL_PATH_120 = "lmsys/gpt-oss-120b-BF16"
# reference https://github.com/huggingface/transformers/blob/68eb1a9a6353911f491b1c8139eb73d052a8e9b9/tests/models/gpt_oss/test_modeling_gpt_oss.py#L397
original_output = "Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio"
# reference https://github.com/huggingface/transformers/blob/68eb1a9a6353911f491b1c8139eb73d052a8e9b9/tests/models/gpt_oss/test_modeling_gpt_oss.py#L462
original_output_120 = "Roses are red, violets are blue,\nI am a language model, not a human being"
original_logprobs = [
-0.037353515625,
-0.08154296875,
-1.21875,
-1.953125,
-2.234375,
-0.96875,
-1.546875,
-1.640625,
-0.93359375,
-1.609375,
-1.625,
-0.85546875,
-1.7265625,
]
original_logprobs_120 = [
-0.90234375,
-0.66015625,
-1.546875,
-2.703125,
-2.078125,
-1.21875,
-2.484375,
-0.031982421875,
-0.84765625,
-1.890625,
-0.1923828125,
-2.046875,
-1.65625,
]


def do_sample(llm: LLM, original_output: str, original_logprobs: list[float], rtol: float, atol: float,
              max_num_seqs: int) -> list[str]:
    prompts = [
"Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
] * max_num_seqs

sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=512,
logprobs=1 if not PT_PROFILE else None,)
outputs = llm.generate(
prompts,
sampling_params)

if not PT_PROFILE:
# Print the outputs.
generated_texts: list[str] = []
logprobs: list[float] = []
for output in outputs:
for probs in output.outputs[0].logprobs:
logprobs.append(list(probs.values())[0].logprob)
prompt = output.prompt
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# assert prompts[0]+generated_texts[0] == original_output, "Generated text does not match the expected output."

        # assert np.allclose(np.array(logprobs[:-1]), np.array(original_logprobs), rtol=rtol, atol=atol), "Logprobs do not match the expected values."  # noqa: E501
return generated_texts
    else:
        generated_texts: list[str] = []
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            generated_texts.append(generated_text)
            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        return generated_texts

if __name__ == "__main__":
DEFAULT_MAX_NUM_SEQS = 1
max_num_seqs = int(sys.argv[1]) if len(sys.argv) > 1 else DEFAULT_MAX_NUM_SEQS
    # Enable PyTorch profiling when the PT_PROFILE env var is set to 1, true, yes, or on.
_pt_profile_env = os.getenv("PT_PROFILE", "0")
PT_PROFILE = _pt_profile_env.lower() in ("1", "true", "yes", "on")
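    # Example invocations (illustrative, based on the argv/env handling above):
    #   python test_gpt_oss_offline.py 4                # run 4 copies of the prompt, no profiling
    #   PT_PROFILE=1 python test_gpt_oss_offline.py 1   # profile a single-prompt run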

if RUN_20B_MODEL:
llm = LLM(MODEL_PATH,
max_num_seqs=8 if not PT_PROFILE else max_num_seqs,
dtype='bfloat16',
enforce_eager=True,
max_model_len=512,
max_num_batched_tokens=2048,
tensor_parallel_size=1,
)
if PT_PROFILE:
import torch
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
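            # With wait=0, warmup=1, active=1, repeat=1: the first do_sample call below serves as
            # warmup, the second is recorded, and the TensorBoard handler writes the trace to ./
            # once the active step completes; the third call runs outside the profiling schedule.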
activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.HPU]
_profiler = torch.profiler.profile(
schedule=schedule,
activities=activities,
on_trace_ready=torch.profiler.tensorboard_trace_handler("./"),
record_shapes=False,
with_stack=False,
)
_profiler.start()
do_sample(llm, original_output=original_output,
original_logprobs=original_logprobs, rtol=1e-01, atol=1e-01, max_num_seqs=max_num_seqs)
_profiler.step()
do_sample(llm, original_output=original_output,
original_logprobs=original_logprobs, rtol=1e-01, atol=1e-01, max_num_seqs=max_num_seqs)
_profiler.step()
do_sample(llm, original_output=original_output,
original_logprobs=original_logprobs, rtol=1e-01, atol=1e-01, max_num_seqs=max_num_seqs)
_profiler.step()
_profiler.stop()
else:
do_sample(llm, original_output=original_output,
original_logprobs=original_logprobs, rtol=1e-01, atol=1e-01, max_num_seqs=max_num_seqs)

else:
llm = LLM(MODEL_PATH_120,
max_num_seqs=8,
dtype='bfloat16',
enforce_eager=False,
max_model_len=512,
max_num_batched_tokens=2048,
tensor_parallel_size=4,
)
do_sample(llm, original_output=original_output_120,
original_logprobs=original_logprobs_120, rtol=1e-01, atol=3e-01, max_num_seqs=max_num_seqs)
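Note: the np.allclose logprob assertion inside do_sample is left commented out in this file. A minimal sketch of how such a check could be wired up, assuming the generated logprobs are truncated to the length of the reference list (the helper name and the truncation are illustrative, not part of this PR):

import numpy as np

def check_logprobs(generated: list[float], reference: list[float], rtol: float, atol: float) -> None:
    # Compare the leading generated logprobs against the reference values taken
    # from the transformers tests referenced at the top of the file.
    got = np.array(generated[:len(reference)])
    want = np.array(reference)
    assert np.allclose(got, want, rtol=rtol, atol=atol), "Logprobs do not match the expected values."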
82 changes: 82 additions & 0 deletions tests/unit_tests/sinks/test_gpt_oss.py
@@ -0,0 +1,82 @@
import vllm
import os
from vllm.entrypoints.llm import LLM

RUN_20B_MODEL = True # Set to False to run the 120B model instead
MODEL_PATH = "lmsys/gpt-oss-20b-BF16"
MODEL_PATH_120 = "lmsys/gpt-oss-120b-BF16"
# reference https://github.com/huggingface/transformers/blob/68eb1a9a6353911f491b1c8139eb73d052a8e9b9/tests/models/gpt_oss/test_modeling_gpt_oss.py#L397
original_output = "Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio"
# reference https://github.com/huggingface/transformers/blob/68eb1a9a6353911f491b1c8139eb73d052a8e9b9/tests/models/gpt_oss/test_modeling_gpt_oss.py#L462
original_output_120 = "Roses are red, violets are blue,\nI am a language model, not a human being"


def do_sample(llm: LLM, original_output: str, rtol: float, atol: float, max_num_seqs: int) -> list[str]:
prompts = [
"Roses are red, violets",
] * max_num_seqs

sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=20,
)
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

assert prompts[0] + generated_texts[0] == original_output, "Generated text does not match the expected output."
return generated_texts


expected_output = [
"are blue, I love you, and I love you too.\n\nRoses are red, vio" # noqa: E501
]


def _test_gpt_oss():
"""Main function that sets up and runs the prompt processing."""
if RUN_20B_MODEL:
llm = LLM(
MODEL_PATH,
max_num_seqs=8,
dtype='bfloat16',
enforce_eager=True,
max_model_len=512,
max_num_batched_tokens=2048,
tensor_parallel_size=1,
)
generated_texts = do_sample(llm, original_output=original_output, rtol=1e-01, atol=1e-01, max_num_seqs=1)
else:
llm = LLM(
MODEL_PATH_120,
max_num_seqs=8,
dtype='bfloat16',
enforce_eager=False,
max_model_len=512,
max_num_batched_tokens=2048,
tensor_parallel_size=4,
)
generated_texts = do_sample(llm, original_output=original_output_120, rtol=1e-01, atol=1e-01, max_num_seqs=1)
assert generated_texts == expected_output

Copilot AI Nov 10, 2025

The assertion compares the full list of generated texts against expected_output, but do_sample returns a list and only its first element is validated earlier. This assertion will fail unless generated_texts contains exactly one element matching expected_output[0]. Consider assert generated_texts[0] == expected_output[0], or keep the list comparison after validating the list length.

Suggested change
assert generated_texts == expected_output
assert len(generated_texts) == len(expected_output)
assert generated_texts[0] == expected_output[0]



def test_gpt_oss_1x():
os.environ['PT_HPU_ENABLE_FUSED_SDPA_SINK'] = '1'
os.environ['PT_HPU_QKV_SLICE_SEQ_LEN_THLD'] = '64'
os.environ['PT_HPU_SDPA_BR_FACTOR'] = '64'
os.environ['PT_HPU_SDPA_BC_FACTOR'] = '64'
os.environ['PT_HPU_SDPA_QKV_SLICE_MODE_FWD'] = '1'
os.environ['VLLM_FUSEDSDPA_SLIDE_THLD'] = '0'
_test_gpt_oss()
os.environ['PT_HPU_ENABLE_FUSED_SDPA_SINK'] = '0'
os.environ['PT_HPU_QKV_SLICE_SEQ_LEN_THLD'] = '1024'
os.environ['PT_HPU_SDPA_BR_FACTOR'] = '1024'
os.environ['PT_HPU_SDPA_BC_FACTOR'] = '1024'
os.environ['PT_HPU_SDPA_QKV_SLICE_MODE_FWD'] = '0'
os.environ['VLLM_FUSEDSDPA_SLIDE_THLD'] = '8192'
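A side note on test_gpt_oss_1x above: it mutates process-wide environment variables and then resets them to hard-coded values after the run. A minimal sketch of a save-and-restore helper that would avoid baking in the reset values (the helper is illustrative and not part of this PR):

import os
from contextlib import contextmanager

@contextmanager
def temp_env(overrides: dict[str, str]):
    # Apply the overrides, then restore the previous values (or unset the keys) on exit.
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old

The test body would then run as: with temp_env({'PT_HPU_ENABLE_FUSED_SDPA_SINK': '1', ...}): _test_gpt_oss().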
25 changes: 18 additions & 7 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -171,6 +171,7 @@ def __init__(
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
sinks: Optional[torch.Tensor] = None,
**kwargs,
) -> None:
torch.nn.Module.__init__(self)
@@ -223,6 +224,11 @@ def __init__(
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl")
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")

def forward(
self,
@@ -401,6 +407,7 @@ def __init__(
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
sinks: Optional[torch.Tensor] = None,
) -> None:
super(AttentionImpl, self).__init__()
if kv_sharing_target_layer_name is not None:
@@ -465,6 +472,11 @@ def __init__(
raise NotImplementedError("Encoder self-attention "
"is not implemented for "
"HPUAttentionImpl")
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")

def _maybe_init_alibi_biases(
self,
@@ -586,13 +598,12 @@ def forward(
common_args = self.common_attention_args(block_list, key_cache, value_cache, attn_metadata.block_size,
k_scales, v_scales)

if self.sliding_window:
if hasattr(attn_metadata, 'window_attn_bias') and attn_metadata.window_attn_bias is not None:
attn_bias = attn_metadata.window_attn_bias
else:
attn_bias = None
window_size = (self.sliding_window, 0)
common_args['window_size'] = window_size
if self.sliding_window and hasattr(attn_metadata,
'window_attn_bias') and attn_metadata.window_attn_bias is not None:
attn_bias = attn_metadata.window_attn_bias
elif self.sliding_window:
window_size = (self.sliding_window, 0)
common_args["window_size"] = window_size

out = ops.prompt_attention(impl=self.prefill_impl,
query=query.view(query_shape),
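For context on the refactored sliding-window branch in forward() above, a standalone restatement of the selection logic (the function name is illustrative; common_args is the kwargs dict later passed to ops.prompt_attention):

from typing import Optional

import torch

def pick_sliding_window_args(sliding_window: Optional[int],
                             window_attn_bias: Optional[torch.Tensor],
                             common_args: dict) -> Optional[torch.Tensor]:
    # Prefer a precomputed sliding-window bias from the attention metadata when one is
    # provided; otherwise only hand the (left, right) window size to the attention kernel.
    if sliding_window and window_attn_bias is not None:
        return window_attn_bias
    if sliding_window:
        common_args["window_size"] = (sliding_window, 0)
    return None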