Support flashinfer for Gemma3 prefill #3167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 4 commits, Apr 17, 2025
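
This PR enables the flashinfer attention backend for Gemma3 prefill. Gemma3 multimodal prompts need a bidirectional attention mask over image tokens, which is now threaded through to flashinfer as a custom prefill mask; the launcher also defaults Gemma3's head dimension to 256 (flashinfer can only be used when the head size is known), and four integration-test snapshots are refreshed (shown first below).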
```diff
@@ -1,26 +1,26 @@
 {
   "choices": [
     {
-      "finish_reason": "stop",
+      "finish_reason": "length",
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail or perhaps speculate about the context of the image?",
+        "content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail, or perhaps suggest what this image might represent (e.g",
         "name": null,
         "role": "assistant",
         "tool_calls": null
       },
       "usage": null
     }
   ],
-  "created": 1741965892,
+  "created": 1744396706,
   "id": "",
   "model": "google/gemma-3-4b-it",
   "object": "chat.completion",
-  "system_fingerprint": "3.2.1-dev0-native",
+  "system_fingerprint": "3.2.3-dev0-native",
   "usage": {
-    "completion_tokens": 98,
+    "completion_tokens": 100,
     "prompt_tokens": 277,
-    "total_tokens": 375
+    "total_tokens": 377
   }
 }
```
```diff
@@ -5,22 +5,22 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nDo you want me to describe any specific element of the image in more detail?",
+        "content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nIf you'd like, you can give me more details about the image or ask me to focus on a specific aspect of it.",
         "name": null,
         "role": "assistant",
         "tool_calls": null
       },
       "usage": null
     }
   ],
-  "created": 1741966313,
+  "created": 1744396703,
   "id": "",
   "model": "google/gemma-3-4b-it",
   "object": "chat.completion",
-  "system_fingerprint": "3.2.1-dev0-native",
+  "system_fingerprint": "3.2.3-dev0-native",
   "usage": {
-    "completion_tokens": 67,
+    "completion_tokens": 78,
     "prompt_tokens": 277,
-    "total_tokens": 344
+    "total_tokens": 355
   }
 }
```
```diff
@@ -13,11 +13,11 @@
       "usage": null
     }
   ],
-  "created": 1741964480,
+  "created": 1744396699,
   "id": "",
   "model": "google/gemma-3-4b-it",
   "object": "chat.completion",
-  "system_fingerprint": "3.2.1-dev0-native",
+  "system_fingerprint": "3.2.3-dev0-native",
   "usage": {
     "completion_tokens": 74,
     "prompt_tokens": 275,
```
```diff
@@ -13,11 +13,11 @@
       "usage": null
     }
   ],
-  "created": 1741964477,
+  "created": 1744396697,
   "id": "",
   "model": "google/gemma-3-4b-it",
   "object": "chat.completion",
-  "system_fingerprint": "3.2.1-dev0-native",
+  "system_fingerprint": "3.2.3-dev0-native",
   "usage": {
     "completion_tokens": 75,
     "prompt_tokens": 279,
```
launcher/src/main.rs (16 additions, 5 deletions)

```diff
@@ -260,11 +260,22 @@ struct Config

 impl Config {
     fn get_head_dim(&self) -> Option<usize> {
-        self.head_dim.or_else(|| {
-            self.text_config
-                .as_ref()
-                .and_then(|text_config| text_config.head_dim)
-        })
+        if let Some(head_dim) = self.head_dim {
+            return Some(head_dim);
+        }
+
+        let text_config = self.text_config.as_ref()?;
+        if let Some(head_size) = text_config.head_dim {
+            return Some(head_size);
+        }
+
+        match self.model_type.as_deref() {
+            // We special-case gemma3 here, since we need flashinfer for
+            // handling bidirectional masks. And flashinfer can only be
+            // used when the head size is known.
+            Some("gemma3") => Some(256),
+            _ => None,
+        }
     }

     fn flop(&self) -> Option<u64> {
```
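The same resolution order, as a minimal Python sketch for readers following along on the server side (a hypothetical helper, not part of this PR; the 256 default presumably mirrors the `head_dim` default of transformers' Gemma3 text config):

```python
from typing import Optional

def get_head_dim(config: dict) -> Optional[int]:
    # An explicit top-level head_dim wins.
    if (head_dim := config.get("head_dim")) is not None:
        return head_dim
    # Multimodal configs nest the text model's settings under text_config.
    text_config = config.get("text_config") or {}
    if (head_dim := text_config.get("head_dim")) is not None:
        return head_dim
    # Gemma3 checkpoints may omit head_dim entirely; 256 is assumed here to
    # match the Gemma3TextConfig default, and flashinfer needs a concrete
    # head size to be selected at all.
    return 256 if config.get("model_type") == "gemma3" else None
```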
server/text_generation_server/layers/attention/flashinfer.py (2 additions)

```diff
@@ -45,6 +45,7 @@ def use_prefill_with_paged_kv_state(
     state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
     block_tables: torch.Tensor,
     cu_seqlens: torch.Tensor,
+    custom_mask: Optional[torch.Tensor],
     input_lengths: torch.Tensor,
     num_heads: int,
     num_kv_heads: int,
@@ -88,6 +89,7 @@
         paged_kv_indptr=indptr,
         paged_kv_indices=block_tables,
         paged_kv_last_page_len=last_page_len,
+        custom_mask=custom_mask,
         num_qo_heads=num_heads,
         num_kv_heads=num_kv_heads,
         head_dim=head_size,
```
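For context: flashinfer takes `custom_mask` as one flattened boolean tensor rather than a per-request 2-D mask. A minimal sketch of that layout, assuming (per the flashinfer docs) row-major flattening of each request's `[q_len, kv_len]` mask, concatenated in batch order:

```python
import torch

def flatten_custom_mask(masks: list[torch.Tensor]) -> torch.Tensor:
    # Each entry is a [q_len_i, kv_len_i] bool tensor; True = "may attend".
    return torch.cat([m.reshape(-1) for m in masks])

# Two requests with (q_len, kv_len) of (3, 3) and (2, 5):
flat = flatten_custom_mask(
    [torch.ones(3, 3, dtype=torch.bool), torch.ones(2, 5, dtype=torch.bool)]
)
assert flat.numel() == 3 * 3 + 2 * 5
```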
```diff
@@ -45,6 +45,7 @@
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.utils.weights import UnquantizedWeight
 from transformers.activations import ACT2FN
 from text_generation_server.layers.attention import (
@@ -248,7 +249,7 @@ def forward(

         # Prefill
         if cu_seqlen_prefill is not None:
-            if attention_mask is None:
+            if attention_mask is None or ATTENTION == "flashinfer":
                 # flash attention
                 attn_output = attention(
                     query=query,
@@ -701,8 +702,16 @@ def __init__(self, prefix, config, weights):
         )

     def get_attention_mask(
-        self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
+        self,
+        input_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        dtype: torch.dtype,
+        bool_mask: bool = False,
     ):
+        image_token_mask = (input_ids == self.config.image_token_index).to(
+            input_ids.device
+        )
+
         device = input_ids.device
         min_dtype = torch.finfo(dtype).min
@@ -748,9 +757,10 @@ def get_attention_mask(
         )
         full_attention_mask[:, :, :, :sequence_length] = combined_mask

-        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
-
-        return final_attention_mask
+        if bool_mask:
+            return full_attention_mask
+        else:
+            return torch.where(full_attention_mask, 0, min_dtype).to(device)

     def forward(
         self,
@@ -793,10 +803,8 @@ def forward(
             )
             attention_mask = self.get_attention_mask(
                 input_ids,
-                max_s,
                 cu_seqlen_prefill,
                 inputs_embeds.dtype,
-                image_token_mask,
             )
             # Use flash attention for text-only input
             # else:
```
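The idea behind `get_attention_mask`, as a simplified single-sequence sketch (an illustration, not the PR's implementation: it treats all image tokens as one bidirectional block, while the real code also handles batching, padding, and multiple images):

```python
import torch

def gemma3_prefill_mask(image_token_mask: torch.Tensor) -> torch.Tensor:
    # image_token_mask: [seq_len] bool, True where the token is an image token.
    seq_len = image_token_mask.shape[0]
    # Ordinary causal attention for text tokens...
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # ...plus bidirectional attention among image tokens.
    bidirectional = image_token_mask[:, None] & image_token_mask[None, :]
    return causal | bidirectional  # True = "may attend"
```

With `bool_mask=True` this boolean form is handed to flashinfer directly; otherwise it is converted to additive form with `torch.where(mask, 0, min_dtype)`, as in the diff above.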
server/text_generation_server/models/flash_causal_lm.py (2 additions)

```diff
@@ -2434,6 +2434,7 @@ def _forward_context(
         input_lengths_tensor: torch.Tensor,
         cache_lengths_tensor: torch.Tensor,
         state: Optional[Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
             return nullcontext()
@@ -2450,6 +2451,7 @@
             ),
             block_tables=block_tables,
             cu_seqlens=cu_seqlen_prefill,
+            custom_mask=attention_mask,
             input_lengths=input_lengths_tensor + cache_lengths_tensor,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
```
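Design note: the mask only flows through `_forward_context` when flashinfer is active; with `ATTENTION != "flashinfer"` the function returns a `nullcontext()` and `attention_mask` is simply ignored, so other attention backends are unaffected.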
server/text_generation_server/models/vlm_causal_lm.py (9 additions)

```diff
@@ -485,6 +485,14 @@ def forward(
         )
         batch.position_ids = position_ids

+        if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
+            # Get the mask, needed for flashinfer.
+            attention_mask = self.model.get_attention_mask(
+                input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
+            ).reshape(-1)
+        else:
+            attention_mask = None
+
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -508,6 +516,7 @@
             cu_seqlen_prefill=cu_seqlen_prefill,
             input_lengths_tensor=input_lengths,
             cache_lengths_tensor=cache_lengths_tensor,
+            attention_mask=attention_mask,
         ):
             seqlen = Seqlen(
                 input_lengths=input_lengths,
```
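Note the `.reshape(-1)`: `get_attention_mask(..., bool_mask=True)` returns a multi-dimensional boolean mask, and flashinfer consumes `custom_mask` as a single flattened tensor (see the sketch after the flashinfer.py diff). The mask is built only when `cu_seqlen_prefill is not None`, so decode steps pass `attention_mask=None` and keep the plain causal paged-KV path.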