@@ -1086,7 +1086,7 @@ def _multi_queries_paged_attention_nonkernel(
     lengths,  # seq_lengths, [batch_size]. nb batch_size = len(seq_lens), the effective kv_length.
     page_indices,  # [batch_size, pages_per_sequence]
     effective_q_lens,  # [batch_size], the effective q_length
-    attn_logits_soft_cap: float | None = None,
+    attn_logits_soft_cap: Optional[float] = None,
 ) -> torch.Tensor:  # [batch_size, query_len, num_heads, head_dim]
   batch_size, query_len, num_query_heads, head_size = q.shape
   num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
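For context (not part of the diff): every hunk in this commit makes the same mechanical substitution. PEP 604 union syntax (`float | None`) only parses on Python 3.10+, whereas `typing.Optional[float]` is semantically identical and also works on 3.8/3.9. A minimal standalone sketch of the equivalence, assuming the real module imports `Optional` at the top:

```python
from typing import Optional, Union

# All three spellings mean the same thing to a type checker;
# only the PEP 604 form needs Python >= 3.10 to parse:
#   cap: float | None = None
cap_a: Optional[float] = None
cap_b: Union[float, None] = None

def apply_cap(x: float, cap: Optional[float] = None) -> float:
  # None conventionally means "feature disabled": pass x through.
  return x if cap is None else min(x, cap)
```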
@@ -1154,7 +1154,7 @@ def multi_queries_paged_attention(
     num_kv_pages_per_compute_block,
     num_queries_per_compute_block,
     use_kernel=True,
-    attn_logits_soft_cap: float | None = None,
+    attn_logits_soft_cap: Optional[float] = None,
 ):  # [batch_size, query_len, num_heads, head_dim]:
   assert len(q.shape) == 4, "q should have 4 dimensions."
   if not use_kernel:
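A note on the parameter being retyped: in the TPU Pallas attention kernels, `attn_logits_soft_cap` conventionally bounds the attention logits with a scaled tanh before the softmax, and `None` disables the cap. A sketch of that convention (an assumption about this kernel's semantics, not code from the diff):

```python
import torch

def soft_cap_logits(logits: torch.Tensor, cap: float) -> torch.Tensor:
  # Squash logits smoothly into (-cap, +cap) so that very large scores
  # saturate instead of dominating the softmax.
  return cap * torch.tanh(logits / cap)
```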
@@ -1672,8 +1672,8 @@ def paged_attention_xla(q: torch.Tensor,
                         lengths: torch.Tensor,
                         page_indices: torch.Tensor,
                         pages_per_compute_block: int,
-                        megacore_mode: str | None = None,
-                        attn_logits_soft_cap: float | None = None):
+                        megacore_mode: Optional[str] = None,
+                        attn_logits_soft_cap: Optional[float] = None):
   return paged_attention(q, k_pages, v_pages, lengths, page_indices,
                          pages_per_compute_block, megacore_mode,
                          attn_logits_soft_cap)
@@ -1686,8 +1686,8 @@ def paged_attention_non_xla(q: torch.Tensor,
                             lengths: torch.Tensor,
                             page_indices: torch.Tensor,
                             pages_per_compute_block: int,
-                            megacore_mode: str | None = None,
-                            attn_logits_soft_cap: float | None = None):
+                            megacore_mode: Optional[str] = None,
+                            attn_logits_soft_cap: Optional[float] = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
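The `_xla` / `_non_xla` pairs in this file follow PyTorch's `torch.library` dispatch pattern: one overload is registered for the `XLA` dispatch key, and a `CompositeExplicitAutograd` fallback covers every other backend (and fake-tensor tracing). A self-contained sketch of that registration shape; the `demo` namespace and `scale` op are hypothetical stand-ins for `XLA_LIB` and the paged-attention ops:

```python
import torch
from torch.library import Library, impl

demo_lib = Library("demo", "DEF")  # hypothetical namespace, for illustration
demo_lib.define("scale(Tensor t, float s) -> Tensor")

@impl(demo_lib, "scale", "CompositeExplicitAutograd")
def scale_fallback(t: torch.Tensor, s: float) -> torch.Tensor:
  # Generic implementation used when no backend-specific kernel applies.
  return t * s

# A backend-specific overload is registered against its dispatch key,
# mirroring the "XLA" registrations in this diff:
#
# @impl(demo_lib, "scale", "XLA")
# def scale_xla(t, s): ...
```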
@@ -1698,17 +1698,17 @@ def paged_attention_non_xla(q: torch.Tensor,


 @impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
-def multi_queries_paged_attention_xla(q: torch.Tensor,
-                                      k_pages: torch.Tensor,
-                                      v_pages: torch.Tensor,
-                                      lengths: torch.Tensor,
-                                      page_indices: torch.Tensor,
-                                      effective_q_lens: torch.Tensor,
-                                      num_kv_pages_per_compute_block: int,
-                                      num_queries_per_compute_block: int,
-                                      use_kernel: bool,
-                                      attn_logits_soft_cap: float |
-                                      None = None):
+def multi_queries_paged_attention_xla(
+    q: torch.Tensor,
+    k_pages: torch.Tensor,
+    v_pages: torch.Tensor,
+    lengths: torch.Tensor,
+    page_indices: torch.Tensor,
+    effective_q_lens: torch.Tensor,
+    num_kv_pages_per_compute_block: int,
+    num_queries_per_compute_block: int,
+    use_kernel: bool,
+    attn_logits_soft_cap: Optional[float] = None):
   return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
                                        page_indices, effective_q_lens,
                                        num_kv_pages_per_compute_block,
@@ -1717,17 +1717,17 @@ def multi_queries_paged_attention_xla(q: torch.Tensor,


 @impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd")
-def multi_queries_paged_attention_non_xla(q: torch.Tensor,
-                                          k_pages: torch.Tensor,
-                                          v_pages: torch.Tensor,
-                                          lengths: torch.Tensor,
-                                          page_indices: torch.Tensor,
-                                          effective_q_lens: torch.Tensor,
-                                          num_kv_pages_per_compute_block: int,
-                                          num_queries_per_compute_block: int,
-                                          use_kernel: bool,
-                                          attn_logits_soft_cap: float |
-                                          None = None):
+def multi_queries_paged_attention_non_xla(
+    q: torch.Tensor,
+    k_pages: torch.Tensor,
+    v_pages: torch.Tensor,
+    lengths: torch.Tensor,
+    page_indices: torch.Tensor,
+    effective_q_lens: torch.Tensor,
+    num_kv_pages_per_compute_block: int,
+    num_queries_per_compute_block: int,
+    use_kernel: bool,
+    attn_logits_soft_cap: Optional[float] = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
@@ -1751,8 +1751,8 @@ def ragged_paged_attention_xla(
     num_queries_per_block: int,
     use_kernel: bool,
     sm_scale: float = 1.0,
-    mask_value: float | None = None,
-    vmem_limit_bytes: int | None = None,
+    mask_value: Optional[float] = None,
+    vmem_limit_bytes: Optional[int] = None,
 ):
   return ragged_paged_attention(
       q,
@@ -1782,8 +1782,8 @@ def ragged_paged_attention_non_xla(q: torch.Tensor,
                                    num_queries_per_block: int,
                                    use_kernel: bool,
                                    sm_scale: float = 1.0,
-                                   mask_value: float | None = None,
-                                   vmem_limit_bytes: int | None = None):
+                                   mask_value: Optional[float] = None,
+                                   vmem_limit_bytes: Optional[int] = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
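Both `_non_xla` variants funnel into `non_xla_attetion` (the misspelling is the actual identifier in the file). Fallbacks registered under `CompositeExplicitAutograd` typically only need to produce an output of the right shape so that tracing can proceed; a sketch under that assumption, with a hypothetical name:

```python
import warnings
import torch

def non_xla_attention_sketch(q: torch.Tensor, k_pages: torch.Tensor,
                             v_pages: torch.Tensor, kind: str) -> torch.Tensor:
  # Warn, then return a shape-only placeholder matching q
  # ([batch_size, query_len, num_heads, head_dim] for the paged ops).
  warnings.warn(f"XLA {kind} attention invoked on non-XLA tensors; "
                "returning an empty placeholder.")
  return torch.empty_like(q)
```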