
Commit 839ac4c

Fix 3.9 Python syntax (#9018)
1 parent 5d3c70e commit 839ac4c

File tree: 5 files changed (+46, -42 lines)
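
Background on why the change is needed: the PEP 604 union spelling `X | Y` in annotations only works on Python 3.10+. On Python 3.9 an annotation such as `float | None` is evaluated when the function is defined and raises a TypeError, because `|` is not implemented between built-in types. The sketch below is a standalone illustration (not code from this repository; the function names are made up) of the failure mode and of the 3.9-compatible `typing` spellings this commit switches to:

from typing import Optional, Union

# On Python 3.9 the following definition fails at import time with
#   TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'
# because the annotation `float | None` is evaluated eagerly:
#
#   def soft_cap(x: float | None = None) -> float: ...
#
# The typing-module equivalents are valid on 3.9 and later:
def soft_cap(attn_logits_soft_cap: Optional[float] = None) -> float:
  # Treat "no cap" as 0.0 in this toy example.
  return 0.0 if attn_logits_soft_cap is None else attn_logits_soft_cap


def to_float(value: Union[int, float]) -> float:
  # Optional[T] is shorthand for Union[T, None]; Union covers the general case.
  return float(value)


print(soft_cap(), soft_cap(30.0), to_float(2))  # 0.0 30.0 2.0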

test/test_multi_queries_paged_attention_kernel.py

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@
 from torch_xla.experimental.pallas_kernels.multi_queries_paged_attention_kernel import paged_attention
 import jax.numpy as jnp
 import numpy as np
+from typing import Optional
 
 jax.config.parse_flags_with_absl()
 
@@ -45,7 +46,7 @@ def _ref_jax_extended_paged_attention(
     lengths,  # [batch_size], the effective kv_length.
     page_indices,  # [batch_size, pages_per_sequence]
     effective_q_lens,  # [batch_size] the effective q_length
-    attn_logits_soft_cap: float | None = None,
+    attn_logits_soft_cap: Optional[float] = None,
 ):
   batch_size, query_len, num_query_heads, head_size = q.shape
   num_kv_heads, total_num_pages, page_size, _ = k_pages.shape

torch_xla/experimental/custom_kernel.py

Lines changed: 32 additions & 32 deletions
@@ -1086,7 +1086,7 @@ def _multi_queries_paged_attention_nonkernel(
     lengths,  # seq_lengths, [batch_size]. nb batch_size = len(seq_lens), the effective kv_length.
     page_indices,  # [batch_size, pages_per_sequence]
     effective_q_lens,  # [batch_size], the effective q_length
-    attn_logits_soft_cap: float | None = None,
+    attn_logits_soft_cap: Optional[float] = None,
 ) -> torch.Tensor:  # [batch_size, query_len, num_heads, head_dim]
   batch_size, query_len, num_query_heads, head_size = q.shape
   num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
@@ -1154,7 +1154,7 @@ def multi_queries_paged_attention(
     num_kv_pages_per_compute_block,
     num_queries_per_compute_block,
     use_kernel=True,
-    attn_logits_soft_cap: float | None = None,
+    attn_logits_soft_cap: Optional[float] = None,
 ):  # [batch_size, query_len, num_heads, head_dim]:
   assert len(q.shape) == 4, "q should have 4 dimensions."
   if not use_kernel:
@@ -1672,8 +1672,8 @@ def paged_attention_xla(q: torch.Tensor,
                         lengths: torch.Tensor,
                         page_indices: torch.Tensor,
                         pages_per_compute_block: int,
-                        megacore_mode: str | None = None,
-                        attn_logits_soft_cap: float | None = None):
+                        megacore_mode: Optional[str] = None,
+                        attn_logits_soft_cap: Optional[float] = None):
   return paged_attention(q, k_pages, v_pages, lengths, page_indices,
                          pages_per_compute_block, megacore_mode,
                          attn_logits_soft_cap)
@@ -1686,8 +1686,8 @@ def paged_attention_non_xla(q: torch.Tensor,
                             lengths: torch.Tensor,
                             page_indices: torch.Tensor,
                             pages_per_compute_block: int,
-                            megacore_mode: str | None = None,
-                            attn_logits_soft_cap: float | None = None):
+                            megacore_mode: Optional[str] = None,
+                            attn_logits_soft_cap: Optional[float] = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
 
 
@@ -1698,17 +1698,17 @@ def paged_attention_non_xla(q: torch.Tensor,
 
 
 @impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
-def multi_queries_paged_attention_xla(q: torch.Tensor,
-                                      k_pages: torch.Tensor,
-                                      v_pages: torch.Tensor,
-                                      lengths: torch.Tensor,
-                                      page_indices: torch.Tensor,
-                                      effective_q_lens: torch.Tensor,
-                                      num_kv_pages_per_compute_block: int,
-                                      num_queries_per_compute_block: int,
-                                      use_kernel: bool,
-                                      attn_logits_soft_cap: float |
-                                      None = None):
+def multi_queries_paged_attention_xla(
+    q: torch.Tensor,
+    k_pages: torch.Tensor,
+    v_pages: torch.Tensor,
+    lengths: torch.Tensor,
+    page_indices: torch.Tensor,
+    effective_q_lens: torch.Tensor,
+    num_kv_pages_per_compute_block: int,
+    num_queries_per_compute_block: int,
+    use_kernel: bool,
+    attn_logits_soft_cap: Optional[float] = None):
   return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
                                        page_indices, effective_q_lens,
                                        num_kv_pages_per_compute_block,
@@ -1717,17 +1717,17 @@ def multi_queries_paged_attention_xla(q: torch.Tensor,
 
 
 @impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd")
-def multi_queries_paged_attention_non_xla(q: torch.Tensor,
-                                          k_pages: torch.Tensor,
-                                          v_pages: torch.Tensor,
-                                          lengths: torch.Tensor,
-                                          page_indices: torch.Tensor,
-                                          effective_q_lens: torch.Tensor,
-                                          num_kv_pages_per_compute_block: int,
-                                          num_queries_per_compute_block: int,
-                                          use_kernel: bool,
-                                          attn_logits_soft_cap: float |
-                                          None = None):
+def multi_queries_paged_attention_non_xla(
+    q: torch.Tensor,
+    k_pages: torch.Tensor,
+    v_pages: torch.Tensor,
+    lengths: torch.Tensor,
+    page_indices: torch.Tensor,
+    effective_q_lens: torch.Tensor,
+    num_kv_pages_per_compute_block: int,
+    num_queries_per_compute_block: int,
+    use_kernel: bool,
+    attn_logits_soft_cap: Optional[float] = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
 
 
@@ -1751,8 +1751,8 @@ def ragged_paged_attention_xla(
     num_queries_per_block: int,
     use_kernel: bool,
     sm_scale: float = 1.0,
-    mask_value: float | None = None,
-    vmem_limit_bytes: int | None = None,
+    mask_value: Optional[float] = None,
+    vmem_limit_bytes: Optional[int] = None,
 ):
   return ragged_paged_attention(
       q,
@@ -1782,8 +1782,8 @@ def ragged_paged_attention_non_xla(q: torch.Tensor,
                                    num_queries_per_block: int,
                                    use_kernel: bool,
                                    sm_scale: float = 1.0,
-                                   mask_value: float | None = None,
-                                   vmem_limit_bytes: int | None = None):
+                                   mask_value: Optional[float] = None,
+                                   vmem_limit_bytes: Optional[int] = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
 
 
torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py

Lines changed: 6 additions & 6 deletions
@@ -2,7 +2,7 @@
 
 from collections.abc import Sequence
 import functools
-from typing import Literal, cast
+from typing import Literal, cast, Union, Optional
 
 import jax
 from jax import lax
@@ -116,7 +116,7 @@ def _flash_attention(
     query_len: int,
     page_size: int,
     head_dim: int,
-    attn_logits_soft_cap: float | None,
+    attn_logits_soft_cap: Optional[float],
 ):
   b, kv_head_idx, q_blk_idx, kv_blk_idx = (
       pl.program_id(0),
@@ -271,7 +271,7 @@ def paged_flash_attention_kernel(
     num_kv_pages_per_compute_block: int,
     mask_value: float,
     query_len: int,
-    attn_logits_soft_cap: float | None,
+    attn_logits_soft_cap: Optional[float],
 ):
   """Pallas kernel for paged attention."""
   b, kv_head_idx, q_blk_idx, kv_blk_idx = (
@@ -440,16 +440,16 @@ def prefetch_next_block():  # pylint: disable=unused-variable
 )
 def paged_attention(
     q: jax.Array,
-    k_pages: jax.Array | quantization_utils.QuantizedTensor,
-    v_pages: jax.Array | quantization_utils.QuantizedTensor,
+    k_pages: Union[jax.Array, quantization_utils.QuantizedTensor],
+    v_pages: Union[jax.Array, quantization_utils.QuantizedTensor],
     lengths: jax.Array,
     page_indices: jax.Array,
     effective_q_lens: jax.Array,
     *,
     mask_value: float = DEFAULT_MASK_VALUE,
     num_kv_pages_per_compute_block: int,
     num_queries_per_compute_block: int = 4,
-    attn_logits_soft_cap: float | None = None,
+    attn_logits_soft_cap: Optional[float] = None,
 ) -> jax.Array:
   """Paged grouped query attention.
 
torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py

Lines changed: 3 additions & 1 deletion
@@ -7,6 +7,8 @@
 """
 
 import functools
+from typing import Optional
+
 import jax
 from jax import lax
 from jax.experimental import pallas as pl
@@ -571,7 +573,7 @@ def ragged_paged_attention(
     mask_value: float = DEFAULT_MASK_VALUE,
     num_kv_pages_per_block: int = 16,
     num_queries_per_block: int = 128,
-    vmem_limit_bytes: int | None = None,
+    vmem_limit_bytes: Optional[int] = None,
 ):
   """Ragged paged attention that supports mixed prefill and decode.
 
torchax/examples/train_llama_torchtitan/splash_attn.py

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,5 @@
 import functools
+from typing import Optional
 
 import jax
 import jax.numpy as jnp
@@ -16,8 +17,8 @@ def tpu_splash_attention(
     query: jax.Array,
     key: jax.Array,
     value: jax.Array,
-    decoder_segment_ids: jax.Array | None,
-    attn_logits_soft_cap: float | None = None,
+    decoder_segment_ids: Optional[jax.Array],
+    attn_logits_soft_cap: Optional[float] = None,
 ) -> jax.Array:
   """TPU Flash Attention."""
   if decoder_segment_ids is not None:
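
One note on the design choice (general Python behavior, not something stated in this commit): adding `from __future__ import annotations` would also let the `|` spelling parse on Python 3.9, since annotations are then stored as strings instead of being evaluated at definition time, but anything that resolves the hints at runtime (for example `typing.get_type_hints`) would still fail on 3.9. Spelling the hints with `Optional`/`Union`, as this commit does, keeps them usable either way. A minimal sketch of the alternative that was not taken:

# Postponed evaluation of annotations (PEP 563): the PEP 604 annotation below
# parses on Python 3.9 because it is never evaluated at definition time.
from __future__ import annotations


def f(x: float | None = None) -> float:
  return 0.0 if x is None else x


print(f(), f(2.5))        # 0.0 2.5
print(f.__annotations__)  # {'x': 'float | None', 'return': 'float'}
# Note: typing.get_type_hints(f) would still raise TypeError on Python 3.9,
# because resolving the stored string re-evaluates `float | None`.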
