@@ -140,7 +140,7 @@ def forward_native(
140140 query : torch .Tensor ,
141141 key : Optional [torch .Tensor ] = None ,
142142 offsets : Optional [torch .Tensor ] = None ,
143- ) -> Tuple [torch .Tensor , torch .Tensor ]:
143+ ) -> Tuple [torch .Tensor , Optional [ torch .Tensor ] ]:
144144 """A PyTorch-native implementation of forward()."""
145145 if offsets is not None :
146146 positions = positions + offsets
@@ -174,7 +174,7 @@ def forward_cuda(
174174 query : torch .Tensor ,
175175 key : Optional [torch .Tensor ] = None ,
176176 offsets : Optional [torch .Tensor ] = None ,
177- ) -> Tuple [torch .Tensor , torch .Tensor ]:
177+ ) -> Tuple [torch .Tensor , Optional [ torch .Tensor ] ]:
178178 from vllm import _custom_ops as ops
179179
180180 # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
@@ -202,7 +202,7 @@ def forward_xpu(
202202 query : torch .Tensor ,
203203 key : Optional [torch .Tensor ] = None ,
204204 offsets : Optional [torch .Tensor ] = None ,
205- ) -> Tuple [torch .Tensor , torch .Tensor ]:
205+ ) -> Tuple [torch .Tensor , Optional [ torch .Tensor ] ]:
206206 from vllm ._ipex_ops import ipex_ops as ops
207207
208208 self .cos_sin_cache = self .cos_sin_cache .to (positions .device ,
@@ -232,7 +232,7 @@ def forward_hpu(
232232 query : torch .Tensor ,
233233 key : Optional [torch .Tensor ] = None ,
234234 offsets : Optional [torch .Tensor ] = None ,
235- ) -> Tuple [torch .Tensor , torch .Tensor ]:
235+ ) -> Tuple [torch .Tensor , Optional [ torch .Tensor ] ]:
236236 from habana_frameworks .torch .hpex .kernels import (
237237 RotaryPosEmbeddingMode , apply_rotary_pos_emb )
238238 if offsets is not None :
@@ -290,7 +290,7 @@ def forward_neuron(
290290 query : torch .Tensor ,
291291 key : Optional [torch .Tensor ] = None ,
292292 offsets : Optional [torch .Tensor ] = None ,
293- ) -> Tuple [torch .Tensor , torch .Tensor ]:
293+ ) -> Tuple [torch .Tensor , Optional [ torch .Tensor ] ]:
294294
295295 def _apply_rotary_emb_neuron (
296296 x : torch .Tensor ,
@@ -688,7 +688,7 @@ def forward(
688688 query : torch .Tensor ,
689689 key : Optional [torch .Tensor ] = None ,
690690 offsets : Optional [torch .Tensor ] = None ,
691- ) -> Tuple [torch .Tensor , torch .Tensor ]:
691+ ) -> Tuple [torch .Tensor , Optional [ torch .Tensor ] ]:
692692 assert key is not None
693693 query = query .view (* query .shape [:- 1 ], - 1 , self .head_size )
694694 key = key .view (* key .shape [:- 1 ], - 1 , self .head_size )
@@ -799,7 +799,7 @@ def forward(
799799 query : torch .Tensor ,
800800 key : Optional [torch .Tensor ] = None ,
801801 offsets : Optional [torch .Tensor ] = None ,
802- ) -> Tuple [torch .Tensor , torch .Tensor ]:
802+ ) -> Tuple [torch .Tensor , Optional [ torch .Tensor ] ]:
803803 """PyTorch-native implementation equivalent to forward()."""
804804 assert key is not None
805805 query_rot = query [..., :self .rotary_dim ]
@@ -929,7 +929,7 @@ def forward(
929929 self ,
930930 query : torch .Tensor ,
931931 key : Optional [torch .Tensor ] = None ,
932- ) -> Tuple [torch .Tensor , torch .Tensor ]:
932+ ) -> Tuple [torch .Tensor , Optional [ torch .Tensor ] ]:
933933 assert key is not None
934934 self .cos_sin_cache : torch .Tensor = self .cos_sin_cache .to (query .device )
935935 query_ = torch .view_as_complex (query .float ().reshape (
@@ -975,7 +975,7 @@ def forward(
975975 positions : torch .Tensor ,
976976 query : torch .Tensor ,
977977 key : Optional [torch .Tensor ] = None ,
978- ) -> Tuple [torch .Tensor , torch .Tensor ]:
978+ ) -> Tuple [torch .Tensor , Optional [ torch .Tensor ] ]:
979979 """PyTorch-native implementation equivalent to forward().
980980
981981 Args:
0 commit comments