Llama4 chunked attention support #395

Open · wants to merge 8 commits into base: add_llama4

43 changes: 36 additions & 7 deletions QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -568,8 +568,15 @@ def forward(
key_states = key_states.transpose(1, 2)

if past_key_value is not None:
chunk_postion_ids = position_ids

if self.use_rope:
chunk_postion_ids = torch.where(
chunk_postion_ids != -1, chunk_postion_ids % self.config.attention_chunk_size, chunk_postion_ids
)

# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_postion_ids}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface: Callable = eager_attention_forward
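
For the chunked (RoPE) layers, the KV cache only holds `attention_chunk_size` slots, so the write position for the cache update is the global position wrapped into the current chunk, while padded positions marked with -1 pass through unchanged. A minimal sketch of that remap with illustrative values (chunk size 8192, the Llama4 default this PR adds to constants):

```python
import torch

# Illustrative only: positions straddling a chunk boundary wrap back to 0,
# padded positions (-1) are left untouched.
attention_chunk_size = 8192
position_ids = torch.tensor([[8190, 8191, 8192, 8193, -1, -1]])

chunk_position_ids = torch.where(
    position_ids != -1, position_ids % attention_chunk_size, position_ids
)
print(chunk_position_ids)  # tensor([[8190, 8191,    0,    1,   -1,   -1]])
```
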
@@ -614,7 +621,7 @@ def forward(
residual = hidden_states

# use local attention mask for ROPE layers
if self.use_chunked_attention and chunk_causal_mask is not None:
if self.use_chunked_attention:
attention_mask = chunk_causal_mask

hidden_states = self.input_layernorm(hidden_states)
@@ -714,11 +721,14 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)

_, chunk_causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
causal_mask = _create_causal_mask(
position_ids=position_ids, target_length=past_key_values.key_cache[3].shape[-2]
)
chunked_position_ids = torch.where(
position_ids != -1, position_ids % self.config.attention_chunk_size, position_ids
)
target_length = min(past_key_values.key_cache[0].shape[-2], torch.tensor(self.config.attention_chunk_size))
chunk_causal_mask = _create_causal_mask(position_ids=chunked_position_ids, target_length=target_length)

# embed positions
hidden_states = inputs_embeds
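
Two masks are built here: a full causal mask sized by the global (NoPE) cache, and a chunked mask whose query positions are the wrapped chunk positions and whose key length is clamped to `attention_chunk_size`. The helper below is a hypothetical stand-in for QEfficient's `_create_causal_mask`, shown only to illustrate the shapes and the clamping; the real helper's signature and semantics may differ.

```python
import torch

def toy_causal_mask(position_ids: torch.Tensor, target_length: int) -> torch.Tensor:
    # Hypothetical stand-in: a query at position p may attend to cache slots 0..p
    # out of target_length slots.
    key_positions = torch.arange(target_length).view(1, 1, -1)
    return position_ids.unsqueeze(-1) >= key_positions

attention_chunk_size = 8192
position_ids = torch.tensor([[8192, 8193]])                 # global positions
chunked_position_ids = position_ids % attention_chunk_size  # -> [[0, 1]]

# NoPE layers attend over the full cached context; RoPE layers attend over at
# most one chunk, hence the clamped target length.
causal_mask = toy_causal_mask(position_ids, target_length=16384)
chunk_causal_mask = toy_causal_mask(chunked_position_ids, target_length=attention_chunk_size)
print(causal_mask.shape, chunk_causal_mask.shape)
# torch.Size([1, 2, 16384]) torch.Size([1, 2, 8192])
```
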
@@ -905,6 +915,15 @@ def get_specializations(

prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
chunk_ctx_len = min(
ctx_len,
(
self.config.text_config.attention_chunk_size
if hasattr(self, "config")
else constants.LLAMA4_ATTENTION_CHUNK_SIZE
),
)

if img_size is None and hasattr(self.config.vision_config, "image_size"):
img_size = getattr(self.config.vision_config, "image_size")
elif img_size is None:
@@ -929,6 +948,8 @@
"batch_size_times_num_tiles": batch_size_times_num_tiles,
"img_size": img_size,
"vision_size": vision_size,
"chunk_length": prefill_seq_len,
"chunk_ctx_len": chunk_ctx_len,
},
{
"batch_size": batch_size,
@@ -937,6 +958,8 @@
"batch_size_times_num_tiles": batch_size_times_num_tiles,
"img_size": img_size,
"vision_size": vision_size,
"chunk_length": prefill_seq_len,
"chunk_ctx_len": chunk_ctx_len,
},
]
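
The chunked-cache specialization length is simply the context length clamped to the attention chunk size; illustrative arithmetic (8192 is the Llama4 default added to constants below):

```python
attention_chunk_size = 8192  # LLAMA4_ATTENTION_CHUNK_SIZE
print(min(131072, attention_chunk_size))  # 8192 -> chunk_ctx_len for a 128K ctx_len
print(min(4096, attention_chunk_size))    # 4096 -> short contexts keep their own length
```
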

@@ -958,8 +981,14 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
lang_dynamic_axes["vision_embeds"] = {0: "vision_size"}
vision_dynamic_axes["pixel_values"] = {0: "batch_size_times_num_tiles", 2: "img_size", 3: "img_size"}

pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
pkv_dynamic_axes = {0: "batch_size"}
for i in range(self.language_model.config.num_hidden_layers):
# switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers.
if int((i + 1) % 4 != 0):
pkv_dynamic_axes[2] = "chunk_ctx_len"
else:
pkv_dynamic_axes[2] = "ctx_len"

for kv in ["key", "value"]:
lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
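
One way to read the loop above (a sketch, not the PR's exact code): every fourth layer is a global (NoPE) layer whose KV sequence axis is named `ctx_len`, and the remaining RoPE layers use `chunk_ctx_len`; building a fresh axes dict per layer keeps each layer's entry independent.

```python
# Sketch of the intended per-layer mapping, assuming a hypothetical 8-layer model.
num_hidden_layers = 8
lang_dynamic_axes = {}
for i in range(num_hidden_layers):
    # Every fourth layer ((i + 1) % 4 == 0) is global/NoPE; the rest are chunked/RoPE.
    axis_name = "ctx_len" if (i + 1) % 4 == 0 else "chunk_ctx_len"
    for kv in ["key", "value"]:
        lang_dynamic_axes[f"past_{kv}.{i}"] = {0: "batch_size", 2: axis_name}

print(lang_dynamic_axes["past_key.0"])  # {0: 'batch_size', 2: 'chunk_ctx_len'}
print(lang_dynamic_axes["past_key.3"])  # {0: 'batch_size', 2: 'ctx_len'}
```
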

32 changes: 29 additions & 3 deletions QEfficient/transformers/models/modeling_auto.py
@@ -1461,14 +1461,38 @@ def export(self, export_dir: Optional[str] = None) -> str:
0: "full_batch_size" if self.continuous_batching else "batch_size",
2: "ctx_len",
}
pkv_dynamic_sliding_axes = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
2: "chunk_attn",
}

output_names = ["logits"]

for i in range(self.num_layers):
is_chunked_attention = torch.tensor(
[bool((i + 1) % 4) for i in range(self.model.config.num_hidden_layers)], dtype=torch.bool
)
global_cache_shape = [1, 8, seq_len, 128]
chunked_cache_shape = [
1,
8,
seq_len,
128,
]

for i in range(self.model.config.num_hidden_layers):
for kv in ["key", "value"]:
example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
cache_shape = global_cache_shape if not is_chunked_attention[i] else chunked_cache_shape
apply_dynamic_axes = pkv_dynamic_axes if not is_chunked_attention[i] else pkv_dynamic_sliding_axes
example_inputs["past_key_values"][i].append(torch.zeros(cache_shape, dtype=torch.float32))
dynamic_axes[f"past_{kv}.{i}"] = apply_dynamic_axes
output_names.append(f"past_{kv}.{i}_RetainedState")

# for i in range(self.num_layers):
# for kv in ["key", "value"]:
# example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
# dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
# output_names.append(f"past_{kv}.{i}_RetainedState")

if self.continuous_batching:
example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
dynamic_axes["batch_index"] = {0: "batch_size"}
@@ -1497,6 +1521,7 @@ def build_prefill_specialization(
"batch_size": 1 if self.continuous_batching else batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"chunk_attn": self.model.config.attention_chunk_size,
"num_logits_to_keep": 1 if self.is_tlm else None,
}
if self.continuous_batching:
@@ -1522,6 +1547,7 @@ def build_decode_specialization(
"batch_size": full_batch_size if self.continuous_batching else batch_size,
"seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1,
"ctx_len": ctx_len,
"chunk_attn": self.model.config.attention_chunk_size,
"num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None,
}
if self.continuous_batching:
1 change: 1 addition & 0 deletions QEfficient/utils/__init__.py
@@ -16,6 +16,7 @@
get_onnx_dir_name,
get_padding_shape_from_config,
get_qpc_dir_path,
get_sliding_window_shapes,
hf_download,
load_hf_processor,
load_hf_tokenizer,
86 changes: 81 additions & 5 deletions QEfficient/utils/_utils.py
@@ -282,6 +282,64 @@ def padding_check_and_fix(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokeni
tokenizer.pad_token_id = tokenizer.vocab_size - 1


def get_sliding_window_shapes(config, batch_size, seq_len):
"""
Gets padding dims from model config - number of kv heads and d_head
and returns padding shape - (batch_size, number of kv heads, seq_len, hidden size)
required for initialization of past_key_values
--------

:config: AutoConfig from pretrained model.
:batch_size: int. number of input prompts used to create inputs
:seq_len: int. sequence length to run the model for.

Return:
List[int, int, int, int]
"""

if hasattr(config, "n_head"): # Assuming n_head is a key in the config (GPTs/CodeGen)
n_heads = config.n_head
d_head = config.n_embd // config.n_head
elif hasattr(config, "num_key_value_heads") and hasattr(
config, "num_attention_heads"
): # Check for num_key_value_heads (Llama/Mistral)
n_heads = config.num_key_value_heads

if hasattr(config, "head_dim"):
d_head = config.head_dim
else:
d_head = config.hidden_size // config.num_attention_heads

elif hasattr(config, "n_heads"): # Check for n_heads and d_model in the config (MPT Model)
n_heads = config.n_heads
d_head = config.d_model // config.n_heads
elif hasattr(config, "new_decoder_architecture"): # Check for Falcon
new_decoder_architecture = getattr(config, "new_decoder_architecture")
if new_decoder_architecture: # multi_query is ignored when new_decoder_architecture is True
n_heads = config.num_attention_heads
else:
if hasattr(config, "multi_query"):
multi_query_value = getattr(config, "multi_query")
if multi_query_value:
n_heads = 1 # MQA , multi query is true
else:
n_heads = config.num_attention_heads
d_head = config.hidden_size // config.num_attention_heads
else:
raise ValueError("Invalid model configuration: n_head/d_heads or num_key_value_heads not found.")

# is_chunked_attention = torch.tensor([bool((i + 1) % 4) for i in range(config.num_hidden_layers)], dtype=torch.bool)
global_cache_shape = [batch_size, n_heads, seq_len, d_head]
chunked_cache_shape = [
batch_size,
n_heads,
seq_len if seq_len < config.attention_chunk_size else config.attention_chunk_size,
d_head,
]

return global_cache_shape, chunked_cache_shape
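
A usage sketch for the new helper, using a hypothetical Llama4-style config built from `SimpleNamespace` (the attribute values are illustrative, not taken from the PR):

```python
from types import SimpleNamespace

from QEfficient.utils import get_sliding_window_shapes

# Hypothetical config attributes, for illustration only.
cfg = SimpleNamespace(
    num_key_value_heads=8,
    num_attention_heads=40,
    head_dim=128,
    hidden_size=5120,
    attention_chunk_size=8192,
)

global_shape, chunked_shape = get_sliding_window_shapes(cfg, batch_size=1, seq_len=131072)
print(global_shape)   # [1, 8, 131072, 128]
print(chunked_shape)  # [1, 8, 8192, 128]
```
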


def get_padding_shape_from_config(config, batch_size, seq_len):
"""
Gets padding dims from model config - number of kv heads and d_head
@@ -327,11 +385,29 @@ def get_padding_shape_from_config(config, batch_size, seq_len):
d_head = config.hidden_size // config.num_attention_heads
else:
raise ValueError("Invalid model configuration: n_head/d_heads or num_key_value_heads not found.")
padding_shape = [batch_size, n_heads, seq_len, d_head]
if hasattr(config, "architectures") and config.architectures is not None: # Check for Starcoder1 - 3D layout
if "GPTBigCodeForCausalLM" in config.architectures:
padding_shape = [batch_size, seq_len, d_head]
return padding_shape

is_chunked_attention = torch.tensor([bool((i + 1) % 4) for i in range(config.num_hidden_layers)], dtype=torch.bool)
global_cache_shape = [batch_size, n_heads, seq_len, d_head]
chunked_cache_shape = [
batch_size,
n_heads,
seq_len if seq_len < config.attention_chunk_size else config.attention_chunk_size,
d_head,
]

past_key_values = []
for i in range(config.num_hidden_layers):
cache_shape = global_cache_shape if not is_chunked_attention[i] else chunked_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=torch.float32)
new_layer_value_cache = torch.zeros(cache_shape, dtype=torch.float32)
pkv = (new_layer_key_cache, new_layer_value_cache)
past_key_values.append(pkv)
return past_key_values
# padding_shape = [batch_size, n_heads, seq_len, d_head]
# if hasattr(config, "architectures") and config.architectures is not None: # Check for Starcoder1 - 3D layout
# if "GPTBigCodeForCausalLM" in config.architectures:
# padding_shape = [batch_size, seq_len, d_head]
# return padding_shape
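
With this change `get_padding_shape_from_config` no longer returns a single shape list; it now builds one `(key, value)` pair of zero tensors per layer, with chunked layers capped at the chunk size along the sequence axis. A hedged usage sketch, reusing the same illustrative config idea plus `num_hidden_layers`:

```python
from types import SimpleNamespace

from QEfficient.utils import get_padding_shape_from_config

# Hypothetical config attributes, for illustration only.
cfg = SimpleNamespace(
    num_key_value_heads=8,
    num_attention_heads=40,
    head_dim=128,
    hidden_size=5120,
    num_hidden_layers=48,
    attention_chunk_size=8192,
)

past_key_values = get_padding_shape_from_config(cfg, batch_size=1, seq_len=131072)
print(len(past_key_values))                # 48 (key, value) pairs, one per layer
print(tuple(past_key_values[0][0].shape))  # (1, 8, 8192, 128)   chunked (RoPE) layer
print(tuple(past_key_values[3][0].shape))  # (1, 8, 131072, 128) global (NoPE) layer
```
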


def get_num_layers_from_config(config):
7 changes: 1 addition & 6 deletions QEfficient/utils/constants.py
@@ -56,12 +56,6 @@ def get_models_dir():

QEFF_MODELS_DIR = get_models_dir()

ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1
ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32
ONNX_EXPORT_EXAMPLE_FBS = 4
ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep
ONNX_EXPORT_OPSET = 13

COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"]

# InternVL constants
@@ -88,6 +82,7 @@

# Llama4 Constants
LLAMA4_NUM_PATCHES = 17
LLAMA4_ATTENTION_CHUNK_SIZE = 8192


class Constants:
44 changes: 33 additions & 11 deletions QEfficient/utils/generate_inputs.py
@@ -8,7 +8,12 @@
import numpy as np
import torch

from QEfficient.utils import get_num_layers_from_config, get_padding_shape_from_config, padding_check_and_fix
from QEfficient.utils import (
get_num_layers_from_config,
get_padding_shape_from_config,
get_sliding_window_shapes,
padding_check_and_fix,
)


class InputHandler:
@@ -33,7 +38,16 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f
self.ctx_len = ctx_len
self.full_batch_size = full_batch_size
self.n_layer = get_num_layers_from_config(config)
self.padding_shape = get_padding_shape_from_config(
# self.padding_shape = get_padding_shape_from_config(
# config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len
# )
self.past_key_values = get_padding_shape_from_config(
config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len
)
self.is_chunked_attention = torch.tensor(
[bool((i + 1) % 4) for i in range(config.num_hidden_layers)], dtype=torch.bool
)
self.global_shape, self.sliding_shape = get_sliding_window_shapes(
config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len
)

@@ -76,13 +90,14 @@ def prepare_pytorch_inputs(self):
inputs["position_ids"] = torch.arange(input_len).view(1, input_len)
inputs["batch_index"] = torch.arange(1).view(-1, 1)

past_key_values = []
for i in range(self.n_layer):
past_key = torch.zeros((self.padding_shape), dtype=torch.float32)
past_value = torch.zeros((self.padding_shape), dtype=torch.float32)
pkv = (past_key, past_value)
past_key_values.append(pkv)
inputs["past_key_values"] = tuple(past_key_values)
# past_key_values = []
# for i in range(self.n_layer):
# past_key = torch.zeros((self.padding_shape), dtype=torch.float32)
# past_value = torch.zeros((self.padding_shape), dtype=torch.float32)
# pkv = (past_key, past_value)
# past_key_values.append(pkv)
# inputs["past_key_values"] = tuple(past_key_values)
inputs["past_key_values"] = tuple(self.past_key_values)

return inputs

@@ -148,9 +163,16 @@ def prepare_ort_inputs(self):
axis=1,
).astype(np.int64)

# for i in range(self.n_layer):
# inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
# inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)

for i in range(self.n_layer):
inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
cache_shape = self.global_shape if not self.is_chunked_attention[i] else self.sliding_shape
inputs["past_key." + str(i)] = np.zeros((cache_shape), dtype=np.float32)
inputs["past_value." + str(i)] = np.zeros((cache_shape), dtype=np.float32)

return inputs

return inputs

Expand Down