
Commit 26c3494

[Submodule] Change FlashInfer to import (#156)
1 parent cb8e198 commit 26c3494

File tree

5 files changed: +17 -24 lines changed

.gitmodules (-3)

@@ -1,3 +0,0 @@
-[submodule "3rdparty/flashinfer"]
-	path = 3rdparty/flashinfer
-	url = https://github.com/flashinfer-ai/flashinfer.git

3rdparty/flashinfer (-1)

This file was deleted.

docs/flashinfer.md (+5 -3)

@@ -5,13 +5,15 @@ It can be used in SGLang runtime to accelerate attention computation.
 
 ### Install flashinfer
 
-Note: The compilation can take a very long time.
+You can install flashinfer via pip as follows for CUDA 12.1.
 
 ```bash
-git submodule update --init --recursive
-pip install 3rdparty/flashinfer/python
+pip install flashinfer -i https://flashinfer.ai/whl/cu121/
 ```
 
+You can look for other CUDA versions at https://github.com/flashinfer-ai/flashinfer?tab=readme-ov-file#installation. If there is no desired version for your environment,
+please build it from source (the compilation takes a long time).
+
 ### Run a Server With Flashinfer Mode
 
 Add `--model-mode flashinfer` argument to enable flashinfer when launching a server.
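
Since the pip wheel now replaces the in-tree submodule build, a quick import check can confirm the new dependency is usable. This is a hypothetical sketch, not part of the commit; it only uses names and arguments that this commit itself relies on (the two wrappers, a 32 MB workspace buffer, and the "NHD" layout) and assumes a CUDA-capable GPU is visible.

```python
# Hypothetical post-install sanity check (not part of this commit).
# Assumes the pip wheel is installed and a CUDA GPU is available.
import torch
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
)

# Same workspace size and KV layout that model_runner.py passes below.
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
print("flashinfer import OK:", type(prefill_wrapper).__name__, type(decode_wrapper).__name__)
```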

python/sglang/srt/layers/radix_attention.py (-8)

@@ -98,12 +98,7 @@ def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
 
         o = input_metadata.prefill_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.qo_indptr,
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
-            input_metadata.kv_indptr,
-            input_metadata.kv_indices,
-            input_metadata.kv_last_page_len,
-            allow_fp16_qk_reduction=True,
         )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)

@@ -114,9 +109,6 @@ def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         o = input_metadata.decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
-            input_metadata.kv_indptr,
-            input_metadata.kv_indices,
-            input_metadata.kv_last_page_len,
         )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
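
The dropped arguments (qo_indptr, kv_indptr, kv_indices, kv_last_page_len, allow_fp16_qk_reduction) are no longer passed per call: the page-table metadata is now bound once per batch via begin_forward in model_runner.py, so forward only needs the query tensor and the paged KV data. As a rough, hypothetical illustration of the reshape round-trip surrounding the wrapper call (runnable on CPU, with made-up sizes and a stand-in for the attention output):

```python
# Hypothetical shapes; shows the .view round-trip around wrapper.forward.
import torch

num_tokens, tp_q_head_num, head_dim = 10, 8, 128  # made-up values
q = torch.randn(num_tokens, tp_q_head_num * head_dim)

# Wrapper input layout: [num_tokens, num_heads, head_dim]
q_heads = q.contiguous().view(-1, tp_q_head_num, head_dim)
# o = input_metadata.prefill_wrapper.forward(q_heads, kv_data) would run here;
# we reuse q_heads as a stand-in for the attention output.
o = q_heads
# Back to the flat [num_tokens, num_heads * head_dim] layout returned to the model.
assert o.view(-1, tp_q_head_num * head_dim).shape == q.shape
```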

python/sglang/srt/managers/router/model_runner.py (+12 -9)

@@ -90,6 +90,11 @@ class InputMetadata:
     decode_wrapper = None
 
     def init_flashinfer_args(self, tp_size):
+        from flashinfer import (
+            BatchDecodeWithPagedKVCacheWrapper,
+            BatchPrefillWithPagedKVCacheWrapper,
+        )
+
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )

@@ -107,11 +112,7 @@ def init_flashinfer_args(self, tp_size):
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
 
-        from flashinfer.ops import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
-
+        workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND

@@ -120,19 +121,21 @@ def init_flashinfer_args(self, tp_size):
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
+            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
             self.prefill_wrapper.begin_forward(
                 self.qo_indptr,
-                self.batch_size,
+                self.kv_indptr,
+                self.kv_indices,
+                self.kv_last_page_len,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
             )
         else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
+            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
             self.decode_wrapper.begin_forward(
                 self.kv_indptr,
+                self.kv_indices,
                 self.kv_last_page_len,
-                self.batch_size,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
                 self.model_runner.model_config.head_dim,
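
With this change, begin_forward receives CSR-style index arrays instead of just the batch size. A small, hypothetical illustration of the indptr construction pattern used above (made-up sequence lengths; assumes a CUDA device, matching the code):

```python
# Hypothetical example of the indptr construction pattern used above.
import torch

extend_seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32, device="cuda")  # made-up lengths
batch_size = extend_seq_lens.numel()

qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)
# qo_indptr == [0, 3, 8, 10]: request i owns query rows qo_indptr[i]:qo_indptr[i+1].
print(qo_indptr)
```

In FlashInfer's paged-KV format, kv_indptr, kv_indices, and kv_last_page_len describe the cache the same way: per-request ranges into a list of page ids, plus how many entries of each request's last page are valid.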
