diff --git a/test_gpt_oss_offline.py b/test_gpt_oss_offline.py
new file mode 100644
index 000000000..20a239f63
--- /dev/null
+++ b/test_gpt_oss_offline.py
@@ -0,0 +1,134 @@
+import os
+import sys
+import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.entrypoints.llm import LLM
+import numpy as np
+
+RUN_20B_MODEL = True # Set to False to run the 120B model instead
+MODEL_PATH = "lmsys/gpt-oss-20b-BF16"
+MODEL_PATH_120 = "lmsys/gpt-oss-120b-BF16"
+# reference https://github.com/huggingface/transformers/blob/68eb1a9a6353911f491b1c8139eb73d052a8e9b9/tests/models/gpt_oss/test_modeling_gpt_oss.py#L397
+original_output = "Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio"
+# reference https://github.com/huggingface/transformers/blob/68eb1a9a6353911f491b1c8139eb73d052a8e9b9/tests/models/gpt_oss/test_modeling_gpt_oss.py#L462
+original_output_120 = "Roses are red, violets are blue,\nI am a language model, not a human being"
+original_logprobs = [
+    -0.037353515625,
+    -0.08154296875,
+    -1.21875,
+    -1.953125,
+    -2.234375,
+    -0.96875,
+    -1.546875,
+    -1.640625,
+    -0.93359375,
+    -1.609375,
+    -1.625,
+    -0.85546875,
+    -1.7265625,
+    ]
+original_logprobs_120 = [
+    -0.90234375,
+    -0.66015625,
+    -1.546875,
+    -2.703125,
+    -2.078125,
+    -1.21875,
+    -2.484375,
+    -0.031982421875,
+    -0.84765625,
+    -1.890625,
+    -0.1923828125,
+    -2.046875,
+    -1.65625,
+    ]
+
+
+def do_sample(llm: LLM, original_output: str, original_logprobs: list[float], rtol: float, atol: float, max_num_seqs:int) -> list[str]:
+    prompts = [
+        "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
+    ] * max_num_seqs
+
+    sampling_params = vllm.SamplingParams(temperature=0,
+                                          max_tokens=512,
+                                          logprobs=1 if not PT_PROFILE else None,)
+    outputs = llm.generate(
+        prompts,
+        sampling_params)
+
+    if not PT_PROFILE:
+        # Print the outputs.
+        generated_texts: list[str] = []
+        logprobs: list[float] = []
+        for output in outputs:
+            for probs in output.outputs[0].logprobs:
+                logprobs.append(list(probs.values())[0].logprob)
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            generated_texts.append(generated_text)
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+        # assert prompts[0]+generated_texts[0] == original_output, "Generated text does not match the expected output."
+        # assert np.allclose(np.array(logprobs[:-1]),np.array(original_logprobs),rtol=rtol, atol=atol), "Logprobs do not match the expected values."
+        return generated_texts
+    else:
+        generated_texts: list[str] = []
+        for output in outputs:
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            generated_texts.append(generated_text)
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+if __name__ == "__main__":
+    DEFAULT_MAX_NUM_SEQS = 1
+    max_num_seqs = int(sys.argv[1]) if len(sys.argv) > 1 else DEFAULT_MAX_NUM_SEQS
+    # Enable PyTorch profiling when PT_PROFILE env var is set to one of the values (1,true,yes,on)
+    _pt_profile_env = os.getenv("PT_PROFILE", "0")
+    PT_PROFILE = _pt_profile_env.lower() in ("1", "true", "yes", "on")
+
+    if RUN_20B_MODEL:
+        llm = LLM(MODEL_PATH,
+                  max_num_seqs=8 if not PT_PROFILE else max_num_seqs,
+                  dtype='bfloat16',
+                  enforce_eager=True,
+                  max_model_len=512,
+                  max_num_batched_tokens=2048,
+                  tensor_parallel_size=1,
+                  )
+        if PT_PROFILE:
+            import torch
+            schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
+            activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.HPU]
+            _profiler = torch.profiler.profile(
+                schedule=schedule,
+                activities=activities,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler("./"),
+                record_shapes=False,
+                with_stack=False,
+            )
+            _profiler.start()
+            do_sample(llm, original_output=original_output,
+                      original_logprobs=original_logprobs, rtol=1e-01, atol=1e-01, max_num_seqs=max_num_seqs)
+            _profiler.step()
+            do_sample(llm, original_output=original_output,
+                      original_logprobs=original_logprobs, rtol=1e-01, atol=1e-01, max_num_seqs=max_num_seqs)
+            _profiler.step()
+            do_sample(llm, original_output=original_output,
+                      original_logprobs=original_logprobs, rtol=1e-01, atol=1e-01, max_num_seqs=max_num_seqs)
+            _profiler.step()
+            _profiler.stop()
+        else:
+            do_sample(llm, original_output=original_output,
+                      original_logprobs=original_logprobs, rtol=1e-01, atol=1e-01, max_num_seqs=max_num_seqs)
+
+    else:
+        llm = LLM(MODEL_PATH_120,
+                  max_num_seqs=8,
+                  dtype='bfloat16',
+                  enforce_eager=False,
+                  max_model_len=512,
+                  max_num_batched_tokens=2048,
+                  tensor_parallel_size=4,
+                  )
+        do_sample(llm, original_output=original_output_120,
+                  original_logprobs=original_logprobs_120, rtol=1e-01, atol=3e-01, max_num_seqs=max_num_seqs)
diff --git a/tests/unit_tests/sinks/test_gpt_oss.py b/tests/unit_tests/sinks/test_gpt_oss.py
new file mode 100644
index 000000000..b90ff7955
--- /dev/null
+++ b/tests/unit_tests/sinks/test_gpt_oss.py
@@ -0,0 +1,82 @@
+import vllm
+import os
+from vllm.entrypoints.llm import LLM
+
+RUN_20B_MODEL = True # Set to False to run the 120B model instead
+MODEL_PATH = "lmsys/gpt-oss-20b-BF16"
+MODEL_PATH_120 = "lmsys/gpt-oss-120b-BF16"
+# reference https://github.com/huggingface/transformers/blob/68eb1a9a6353911f491b1c8139eb73d052a8e9b9/tests/models/gpt_oss/test_modeling_gpt_oss.py#L397
+original_output = "Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio"
+# reference https://github.com/huggingface/transformers/blob/68eb1a9a6353911f491b1c8139eb73d052a8e9b9/tests/models/gpt_oss/test_modeling_gpt_oss.py#L462
+original_output_120 = "Roses are red, violets are blue,\nI am a language model, not a human being"
+
+
+def do_sample(llm: LLM, original_output: str, rtol: float, atol: float, max_num_seqs: int) -> list[str]:
+    prompts = [
+        "Roses are red, violets",
+    ] * max_num_seqs
+
+    sampling_params = vllm.SamplingParams(
+        temperature=0,
+        max_tokens=20,
+    )
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    generated_texts: list[str] = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+    assert prompts[0] + generated_texts[0] == original_output, "Generated text does not match the expected output."
+    return generated_texts
+
+
+expected_output = [
+    "are blue, I love you, and I love you too.\n\nRoses are red, vio" # noqa: E501
+]
+
+
+def _test_gpt_oss():
+    """Main function that sets up and runs the prompt processing."""
+    if RUN_20B_MODEL:
+        llm = LLM(
+            MODEL_PATH,
+            max_num_seqs=8,
+            dtype='bfloat16',
+            enforce_eager=True,
+            max_model_len=512,
+            max_num_batched_tokens=2048,
+            tensor_parallel_size=1,
+        )
+        generated_texts = do_sample(llm, original_output=original_output, rtol=1e-01, atol=1e-01, max_num_seqs=1)
+    else:
+        llm = LLM(
+            MODEL_PATH_120,
+            max_num_seqs=8,
+            dtype='bfloat16',
+            enforce_eager=False,
+            max_model_len=512,
+            max_num_batched_tokens=2048,
+            tensor_parallel_size=4,
+        )
+        generated_texts = do_sample(llm, original_output=original_output_120, rtol=1e-01, atol=1e-01, max_num_seqs=1)
+    assert generated_texts == expected_output
+
+
+def test_gpt_oss_1x():
+    os.environ['PT_HPU_ENABLE_FUSED_SDPA_SINK'] = '1'
+    os.environ['PT_HPU_QKV_SLICE_SEQ_LEN_THLD'] = '64'
+    os.environ['PT_HPU_SDPA_BR_FACTOR'] = '64'
+    os.environ['PT_HPU_SDPA_BC_FACTOR'] = '64'
+    os.environ['PT_HPU_SDPA_QKV_SLICE_MODE_FWD'] = '1'
+    os.environ['VLLM_FUSEDSDPA_SLIDE_THLD'] = '0'
+    _test_gpt_oss()
+    os.environ['PT_HPU_ENABLE_FUSED_SDPA_SINK'] = '0'
+    os.environ['PT_HPU_QKV_SLICE_SEQ_LEN_THLD'] = '1024'
+    os.environ['PT_HPU_SDPA_BR_FACTOR'] = '1024'
+    os.environ['PT_HPU_SDPA_BC_FACTOR'] = '1024'
+    os.environ['PT_HPU_SDPA_QKV_SLICE_MODE_FWD'] = '0'
+    os.environ['VLLM_FUSEDSDPA_SLIDE_THLD'] = '8192'
diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py
index c43251d97..96146f718 100644
--- a/vllm_gaudi/attention/backends/hpu_attn.py
+++ b/vllm_gaudi/attention/backends/hpu_attn.py
@@ -171,6 +171,7 @@ def __init__(
         qk_head_dim: int,
         v_head_dim: int,
         kv_b_proj: ColumnParallelLinear,
+        sinks: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> None:
         torch.nn.Module.__init__(self)
@@ -223,6 +224,11 @@ def __init__(
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "TritonMLAImpl")
+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
+                                                 f"heads in the layer. Sinks shape: {sinks.shape}, "
+                                                 f"num_heads: {num_heads}.")
 
     def forward(
         self,
@@ -401,6 +407,7 @@ def __init__(
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         super(AttentionImpl, self).__init__()
         if kv_sharing_target_layer_name is not None:
@@ -465,6 +472,11 @@ def __init__(
             raise NotImplementedError("Encoder self-attention "
                                       "is not implemented for "
                                       "HPUAttentionImpl")
+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
+                                                 f"heads in the layer. Sinks shape: {sinks.shape}, "
+                                                 f"num_heads: {num_heads}.")
 
     def _maybe_init_alibi_biases(
         self,
@@ -586,13 +598,12 @@ def forward(
         common_args = self.common_attention_args(block_list, key_cache, value_cache, attn_metadata.block_size,
                                                  k_scales, v_scales)
 
-        if self.sliding_window:
-            if hasattr(attn_metadata, 'window_attn_bias') and attn_metadata.window_attn_bias is not None:
-                attn_bias = attn_metadata.window_attn_bias
-            else:
-                attn_bias = None
-                window_size = (self.sliding_window, 0)
-                common_args['window_size'] = window_size
+        if self.sliding_window and hasattr(attn_metadata,
+                                           'window_attn_bias') and attn_metadata.window_attn_bias is not None:
+            attn_bias = attn_metadata.window_attn_bias
+        elif self.sliding_window:
+            window_size = (self.sliding_window, 0)
+            common_args["window_size"] = window_size
 
         out = ops.prompt_attention(impl=self.prefill_impl,
                                    query=query.view(query_shape),
diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py
index 51b42d6b5..a4deb06d6 100644
--- a/vllm_gaudi/extension/ops.py
+++ b/vllm_gaudi/extension/ops.py
@@ -64,8 +64,8 @@ def matmul_shape(lhs, rhs):
     return result
 
 
-def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, batch_size, matmul_av_op, batch2block_matmul_op,
-                 block2batch_matmul_op):
+def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, block_size, batch_size, matmul_av_op,
+                 batch2block_matmul_op, block2batch_matmul_op):
     # When fp32_softmax is enabled attn is left in fp32 after Q@K
     # We can return to native dtype after we renormalize and calculate the adjustments
     if block_bias is not None and attn.dtype != block_bias.dtype:
@@ -79,11 +79,27 @@ def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, batch_siz
     if block_bias is not None:
         attn.add_(block_bias)
     block_max = attn.amax(dim=-1, keepdim=True)
+    if sink is not None:
+        block_max = torch.maximum(block_max, sink)
     attn = attn.sub(block_max)
     attn = attn.exp()
     if attn.dtype == torch.float32:
         attn = attn.to(value.dtype)
-    block_sums = attn.sum(dim=-1, keepdim=True)
+    attn_shape = attn.shape
+    block_sums = attn.view(-1, attn_shape[-1]).sum(dim=-1, keepdim=True)
+    attn_shape = list(attn_shape)
+    attn_shape[-1] = 1
+    block_sums = block_sums.view(attn_shape)
+    if sink is not None:
+        attn_sink = sink.sub(block_max)
+        attn_sink = attn_sink.exp()
+        if attn_sink.dtype == torch.float32:
+            attn_sink = attn_sink.to(value.dtype)
+        #TODO: Removing this .sum and using attn_sink directly
+        #results in wrong output which does not make sense.
+        #Looks like a Synapse issue, need to investigate further.
+        block_sums_sink = attn_sink.sum(dim=-1, keepdim=True)
+        block_sums = block_sums + block_sums_sink
     attn = matmul_av_op(attn, value)
     if get_config().fused_block_softmax_adjustment:
         out_shape = list(attn.shape[:3]) + [1] * (attn.dim() - 3)
@@ -194,6 +210,13 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias
     value = values_fetch_func(value_cache.unflatten(0, (-1, block_size)),
                              **get_kv_fetch_extra_args(blocks=block_list, scales=v_scales_uf)).transpose(1, 2)
     block_bias = block_bias.view(key.size(0), 1, 1, -1)
+    sink = None
+    if sinks is not None:
+        sinks = sinks.reshape(sinks.shape[0], 1)
+        sink = sinks.reshape(1, sinks.shape[0], 1, sinks.shape[1])
+        sink = sink.expand(query.shape[0], -1, query.shape[-2], -1)
+        if kv_heads != q_heads:
+            sink = sink.unflatten(1, (kv_heads, -1))
     if kv_heads != q_heads:
         query = query.unflatten(1, (kv_heads, -1))
         key = key.unflatten(1, (kv_heads, 1))
@@ -231,6 +254,8 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias
         block_bias,
         block_groups,
         block_mapping,
+        sink,
+        block_size,
         batch_size=batch_size,
         matmul_av_op=matmul_av_op,
         batch2block_matmul_op=batch2block_matmul_op,
@@ -289,6 +314,7 @@ def _naive_prompt_attention(query: torch.Tensor,
                            matmul_qk_op=torch.matmul,
                            softmax_op=torch.softmax,
                            matmul_av_op=torch.matmul,
+                           sinks: Optional[torch.Tensor] = None,
                            **ignored_args) -> torch.Tensor:
     query = query.transpose(1, 2)
     key = key.transpose(1, 2)
@@ -320,10 +346,19 @@ def _naive_prompt_attention(query: torch.Tensor,
         if attn_weights.dtype != attn_bias.dtype:
             attn_bias = attn_bias.to(dtype=attn_weights.dtype)
         attn_weights.add_(attn_bias)
+    if sinks is not None:
+        sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
+        if query_heads != kv_heads:
+            sink = sink.unflatten(1, (kv_heads, -1))
+        combined_logits = torch.cat([attn_weights, sink], dim=-1)
+        combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
+        attn_weights = combined_logits
     if get_config().fp32_softmax:
         attn_weights = torch.softmax(attn_weights, dim=-1)
     else:
         attn_weights = softmax_op(attn_weights, dim=-1)
+    if sinks is not None:
+        attn_weights = attn_weights[..., :-1]
     attn_weights = attn_weights.to(query.dtype)
     attn_weights = matmul_av_op(attn_weights, value)
 
@@ -342,6 +377,7 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
                            attn_bias: Optional[torch.Tensor] = None,
                            valid_seq_lengths: Optional[torch.Tensor] = None,
                            window_size: Optional[int] = None,
+                           sinks: Optional[torch.Tensor] = None,
                            **ignored_args) -> torch.Tensor:
     query = query.transpose(1, 2)
     key = key.transpose(1, 2)
@@ -358,15 +394,24 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
         # TODO: causal + attn_bias is not yet supported
         is_causal = False
         valid_seq_lengths = None
-
+    # TODO - remove this once fsdpa op support fast mode for sliding window
+    if window_size is not None:
+        #causal window sdpa kernel only supports softmax None
+        softmax_mode = 'None'
     args = [
         query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths,
         padding_side
     ]
-    args += [window_size] if window_size else []
-    attn_weights = fsdpa_op(*args)
+    args += [window_size] if window_size else [None]
+    # use sinks in fsdpa
+    if sinks is not None:
+        args += [sinks]
+    attn_weights = fsdpa_op(*args)
 
     attn_weights = attn_weights.transpose(1, 2)
+    if sinks is not None:
+        # TODO - check if we can remove this
+        htcore.mark_step()
     return attn_weights
 
 
@@ -486,6 +531,9 @@ def __init__(self):
     def set_weight(self, w):
         self.weight = w
 
+    def set_bias(self, b):
+        self.bias = b
+
     def forward(self, state, expert_id, w):
         raise NotImplementedError()
 
diff --git a/vllm_gaudi/extension/utils.py b/vllm_gaudi/extension/utils.py
index 59e92b285..d0005ba75 100644
--- a/vllm_gaudi/extension/utils.py
+++ b/vllm_gaudi/extension/utils.py
@@ -157,14 +157,16 @@ def forward(
         valid_sequence_lengths,
         padding_side="left",
         window_size=None,
+        sinks=None,
    ):
-        if window_size is not None:
+        if window_size:
            return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale,
                                                softmax_mode, recompute_mode, valid_sequence_lengths, padding_side,
-                                                False, False, window_size)
+                                                False, False, window_size, sinks)
        else:
            return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
-                                                recompute_mode, valid_sequence_lengths, padding_side)
+                                                recompute_mode, valid_sequence_lengths, padding_side, False, False,
+                                                (-1, -1), sinks)
 
 
 def pad_list(input, target_len, val_generator):
diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py
index ba0b8d96b..4c58c771f 100644
--- a/vllm_gaudi/ops/hpu_fused_moe.py
+++ b/vllm_gaudi/ops/hpu_fused_moe.py
@@ -15,12 +15,15 @@ class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         torch.hpu.synchronize()
+        vllm_config = get_current_vllm_config()
+        self.model_type = vllm_config.model_config.hf_config.model_type
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
         # custom handling for HPU
         num_experts = layer.local_num_experts
         ep_shift = layer.ep_rank * num_experts
+        has_bias = hasattr(layer, 'w13_bias') and hasattr(layer, 'w2_bias')
 
         experts_min, experts_max = ep_shift, num_experts + ep_shift - 1
         layer.moe_op = VllmMixtureOfExpertsOp(
@@ -28,11 +31,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             num_experts,
             experts_min,
             experts_max,
+            bias=has_bias,
         )
 
         for expert_id in range(layer.local_num_experts):
             layer.moe_op.w13_list[expert_id].set_weight(layer.w13_weight.data[expert_id])
             layer.moe_op.w2_list[expert_id].set_weight(layer.w2_weight.data[expert_id])
+            if has_bias:
+                layer.moe_op.w13_list[expert_id].set_bias(layer.w13_bias.data[expert_id])
+                layer.moe_op.w2_list[expert_id].set_bias(layer.w2_bias.data[expert_id])
 
     def forward_oot(
         self,