10 changes: 5 additions & 5 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -13,7 +13,7 @@
 import vllm_gaudi.extension.kernels as kernels
 import vllm_gaudi.extension.ops as ops
 from vllm_gaudi.extension.runtime import get_config
-from vllm_gaudi.extension.utils import (FP8Matmul, Matmul, ModuleFusedSDPA, Softmax, VLLMFP8KVCache, VLLMKVCache)
+from vllm_gaudi.extension.utils import (FP8Matmul, Matmul, B2BMatmul, ModuleFusedSDPA, Softmax, VLLMFP8KVCache, VLLMKVCache)
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionMetadata,
                                               AttentionType)
@@ -226,9 +226,9 @@ def __init__(
         self.softmax = Softmax()
         self.matmul_av = Matmul() if not self.enable_fp8_attn \
             else FP8Matmul()
-        self.batch2block_matmul = Matmul() if not self.enable_fp8_attn \
+        self.batch2block_matmul = B2BMatmul() if not self.enable_fp8_attn \
             else FP8Matmul()
-        self.block2batch_matmul = Matmul() if not self.enable_fp8_attn \
+        self.block2batch_matmul = B2BMatmul() if not self.enable_fp8_attn \
             else FP8Matmul()
         self.latent_cache_k = VLLMKVCache() if not self.enable_fp8_attn \
             else VLLMFP8KVCache()
@@ -445,9 +445,9 @@ def __init__(
         self.softmax = Softmax()
         self.matmul_av = Matmul() if not self.enable_fp8_attn \
             else FP8Matmul()
-        self.batch2block_matmul = Matmul() if not self.enable_fp8_attn \
+        self.batch2block_matmul = B2BMatmul() if not self.enable_fp8_attn \
             else FP8Matmul()
-        self.block2batch_matmul = Matmul() if not self.enable_fp8_attn \
+        self.block2batch_matmul = B2BMatmul() if not self.enable_fp8_attn \
             else FP8Matmul()
         self.k_cache = VLLMKVCache() if not self.enable_fp8_attn \
             else VLLMFP8KVCache()
4 changes: 4 additions & 0 deletions vllm_gaudi/extension/utils.py
@@ -29,6 +29,10 @@ def __init__(self):
     def forward(self, x, y, **kwargs):
         return torch.matmul(x, y, **kwargs)
 
+class B2BMatmul(Matmul):
+    def __init__(self):
+        super().__init__()
+
 
 class Softmax(torch.nn.Module):
 
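Taken together, the change is small: `B2BMatmul` adds no behaviour on top of `Matmul` (it inherits `forward` unchanged), but it gives the batch2block/block2batch matmuls their own module type, presumably so that tooling which walks the module tree (quantization, calibration, patching) can treat these ops separately from ordinary `Matmul` instances. Below is a minimal, self-contained sketch of the resulting pattern; the `FP8Matmul` stand-in, the `AttnSketch` class, and the `enable_fp8_attn` flag handling are illustrative simplifications, not code from the repo.

```python
import torch
from torch import nn


class Matmul(nn.Module):
    """Thin wrapper around torch.matmul, mirroring vllm_gaudi/extension/utils.py."""

    def forward(self, x, y, **kwargs):
        return torch.matmul(x, y, **kwargs)


class B2BMatmul(Matmul):
    """Behaviourally identical to Matmul; exists only as a distinct type
    for the batch2block/block2batch matmuls."""

    def __init__(self):
        super().__init__()


class FP8Matmul(Matmul):
    """Illustrative stand-in for the FP8 wrapper (assumption: the real one
    applies fp8 scaling around the matmul)."""


class AttnSketch(nn.Module):
    """Hypothetical, simplified version of the __init__ pattern changed in hpu_attn.py."""

    def __init__(self, enable_fp8_attn: bool = False):
        super().__init__()
        self.enable_fp8_attn = enable_fp8_attn
        self.matmul_av = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        # The only change in this diff: these two ops now use B2BMatmul.
        self.batch2block_matmul = B2BMatmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.block2batch_matmul = B2BMatmul() if not self.enable_fp8_attn \
            else FP8Matmul()


attn = AttnSketch(enable_fp8_attn=False)

# Module-walking tooling can now single these ops out by class rather than by attribute name.
b2b = [name for name, m in attn.named_modules() if isinstance(m, B2BMatmul)]
print(b2b)  # ['batch2block_matmul', 'block2batch_matmul']

# Functionally, B2BMatmul still just calls torch.matmul.
x, y = torch.randn(2, 3), torch.randn(3, 4)
assert torch.allclose(attn.batch2block_matmul(x, y), x @ y)
```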