diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py
index 8d62aa463..3425bc636 100644
--- a/vllm_gaudi/attention/backends/hpu_attn.py
+++ b/vllm_gaudi/attention/backends/hpu_attn.py
@@ -13,7 +13,8 @@
 import vllm_gaudi.extension.kernels as kernels
 import vllm_gaudi.extension.ops as ops
 from vllm_gaudi.extension.runtime import get_config
-from vllm_gaudi.extension.utils import (FP8Matmul, Matmul, ModuleFusedSDPA, Softmax, VLLMFP8KVCache, VLLMKVCache)
+from vllm_gaudi.extension.utils import (FP8Matmul, Matmul, B2BMatmul, ModuleFusedSDPA, Softmax, VLLMFP8KVCache,
+                                        VLLMKVCache)
 
 from vllm.v1.attention.backend import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionMetadata,
                                        AttentionType)
@@ -226,9 +227,9 @@ def __init__(
         self.softmax = Softmax()
         self.matmul_av = Matmul() if not self.enable_fp8_attn \
             else FP8Matmul()
-        self.batch2block_matmul = Matmul() if not self.enable_fp8_attn \
+        self.batch2block_matmul = B2BMatmul() if not self.enable_fp8_attn \
             else FP8Matmul()
-        self.block2batch_matmul = Matmul() if not self.enable_fp8_attn \
+        self.block2batch_matmul = B2BMatmul() if not self.enable_fp8_attn \
             else FP8Matmul()
         self.latent_cache_k = VLLMKVCache() if not self.enable_fp8_attn \
             else VLLMFP8KVCache()
@@ -445,9 +446,9 @@ def __init__(
         self.softmax = Softmax()
         self.matmul_av = Matmul() if not self.enable_fp8_attn \
             else FP8Matmul()
-        self.batch2block_matmul = Matmul() if not self.enable_fp8_attn \
+        self.batch2block_matmul = B2BMatmul() if not self.enable_fp8_attn \
             else FP8Matmul()
-        self.block2batch_matmul = Matmul() if not self.enable_fp8_attn \
+        self.block2batch_matmul = B2BMatmul() if not self.enable_fp8_attn \
             else FP8Matmul()
         self.k_cache = VLLMKVCache() if not self.enable_fp8_attn \
             else VLLMFP8KVCache()
diff --git a/vllm_gaudi/extension/utils.py b/vllm_gaudi/extension/utils.py
index bcdd05b21..41504f570 100644
--- a/vllm_gaudi/extension/utils.py
+++ b/vllm_gaudi/extension/utils.py
@@ -30,6 +30,19 @@ def forward(self, x, y, **kwargs):
         return torch.matmul(x, y, **kwargs)
 
 
+class B2BMatmul(Matmul):
+    """Specialized alias for batch2block and block2batch matmul operations.
+
+    This class remains functionally identical to ``Matmul`` but is used to
+    semantically mark B2B-related matmuls. This enables the system to apply the
+    fix that uses the B2B output measurements as the input measurements during
+    calibration, avoiding corrupted scales from the KV-cache.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+
 class Softmax(torch.nn.Module):
 
     def __init__(self):
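
The value of a functionally identical subclass is that calibration tooling can key its behavior off the module type: tagging the batch2block/block2batch matmuls with a distinct class lets their measurements be handled differently without changing the math. The sketch below is illustrative only and is not part of the patch; the simplified `Matmul`/`B2BMatmul` stand-ins and the `collect_b2b_modules` and `TinyAttn` names are hypothetical, not vLLM or calibration-library APIs.

```python
# Illustrative sketch (not part of the patch): a marker subclass lets
# calibration code single out batch2block/block2batch matmuls by type.
import torch


class Matmul(torch.nn.Module):
    """Simplified stand-in for vllm_gaudi.extension.utils.Matmul."""

    def forward(self, x, y, **kwargs):
        return torch.matmul(x, y, **kwargs)


class B2BMatmul(Matmul):
    """Functionally identical subclass used purely as a semantic marker."""


def collect_b2b_modules(model: torch.nn.Module) -> list[str]:
    """Return names of modules that should get the B2B calibration fix
    (output measurements reused as input measurements)."""
    return [name for name, mod in model.named_modules() if isinstance(mod, B2BMatmul)]


class TinyAttn(torch.nn.Module):
    """Toy module mirroring how the attention impl instantiates its matmuls."""

    def __init__(self):
        super().__init__()
        self.matmul_av = Matmul()
        self.batch2block_matmul = B2BMatmul()
        self.block2batch_matmul = B2BMatmul()


print(collect_b2b_modules(TinyAttn()))
# -> ['batch2block_matmul', 'block2batch_matmul']
```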