diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py
index 15c8dd6ac..ffdfabcc8 100644
--- a/QEfficient/transformers/models/llama4/modeling_llama4.py
+++ b/QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -568,8 +568,15 @@ def forward(
         key_states = key_states.transpose(1, 2)
 
         if past_key_value is not None:
+            chunk_position_ids = position_ids
+
+            if self.use_rope:
+                chunk_position_ids = torch.where(
+                    chunk_position_ids != -1, chunk_position_ids % self.config.attention_chunk_size, chunk_position_ids
+                )
+
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
+            cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_position_ids}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface: Callable = eager_attention_forward
@@ -614,7 +621,7 @@ def forward(
         residual = hidden_states
 
         # use local attention mask for ROPE layers
-        if self.use_chunked_attention and chunk_causal_mask is not None:
+        if self.use_chunked_attention:
             attention_mask = chunk_causal_mask
 
         hidden_states = self.input_layernorm(hidden_states)
@@ -714,11 +721,14 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)
-
-        _, chunk_causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        causal_mask = _create_causal_mask(
+            position_ids=position_ids, target_length=past_key_values.key_cache[3].shape[-2]
+        )
+        chunked_position_ids = torch.where(
+            position_ids != -1, position_ids % self.config.attention_chunk_size, position_ids
         )
+        target_length = min(past_key_values.key_cache[0].shape[-2], torch.tensor(self.config.attention_chunk_size))
+        chunk_causal_mask = _create_causal_mask(position_ids=chunked_position_ids, target_length=target_length)
 
         # embed positions
         hidden_states = inputs_embeds
@@ -905,6 +915,15 @@ def get_specializations(
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
 
+        chunk_ctx_len = min(
+            ctx_len,
+            (
+                self.config.text_config.attention_chunk_size
+                if hasattr(self, "config")
+                else constants.LLAMA4_ATTENTION_CHUNK_SIZE
+            ),
+        )
+
         if img_size is None and hasattr(self.config.vision_config, "image_size"):
             img_size = getattr(self.config.vision_config, "image_size")
         elif img_size is None:
@@ -929,6 +948,8 @@
                 "batch_size_times_num_tiles": batch_size_times_num_tiles,
                 "img_size": img_size,
                 "vision_size": vision_size,
+                "chunk_length": prefill_seq_len,
+                "chunk_ctx_len": chunk_ctx_len,
             },
             {
                 "batch_size": batch_size,
@@ -937,6 +958,8 @@
                 "batch_size_times_num_tiles": batch_size_times_num_tiles,
                 "img_size": img_size,
                 "vision_size": vision_size,
+                "chunk_length": prefill_seq_len,
+                "chunk_ctx_len": chunk_ctx_len,
             },
         ]
 
@@ -958,8 +981,14 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
             lang_dynamic_axes["vision_embeds"] = {0: "vision_size"}
             vision_dynamic_axes["pixel_values"] = {0: "batch_size_times_num_tiles", 2: "img_size", 3: "img_size"}
 
-        pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
+            # Switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers; build a
+            # fresh axes dict per layer so the entries do not all alias the same mapping.
+            if (i + 1) % 4 != 0:
+                pkv_dynamic_axes = {0: "batch_size", 2: "chunk_ctx_len"}
+            else:
+                pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
+
             for kv in ["key", "value"]:
                 lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
 
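Note: the chunked-attention changes above wrap position_ids modulo attention_chunk_size for RoPE layers, so cache writes and the chunked causal mask index into a chunk-sized KV cache while padded positions (-1) are left untouched. The following is a minimal, standalone sketch of that index arithmetic, using an illustrative chunk size and plain PyTorch ops; it is not QEfficient's _create_causal_mask and ignores chunk-boundary and padding-row details.

import torch

attention_chunk_size = 8  # illustrative; the Llama4 default is 8192
position_ids = torch.tensor([[0, 1, 2, 9, 10, -1, -1]])

# Wrap real positions into the chunk; keep padding (-1) as-is.
chunk_position_ids = torch.where(
    position_ids != -1, position_ids % attention_chunk_size, position_ids
)
print(chunk_position_ids)  # tensor([[ 0,  1,  2,  1,  2, -1, -1]])

# Chunked-mask sketch: keys only up to min(cache length, chunk size), and each
# query may attend to key slots at or before its wrapped position.
target_length = min(32, attention_chunk_size)
key_pos = torch.arange(target_length).view(1, 1, -1)
chunk_causal_mask = (key_pos <= chunk_position_ids.unsqueeze(-1)).unsqueeze(1)
print(chunk_causal_mask.shape)  # torch.Size([1, 1, 7, 8])
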
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 10db37798..b6e2a0f27 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -1461,14 +1461,38 @@ def export(self, export_dir: Optional[str] = None) -> str:
             0: "full_batch_size" if self.continuous_batching else "batch_size",
             2: "ctx_len",
         }
+        pkv_dynamic_sliding_axes = {
+            0: "full_batch_size" if self.continuous_batching else "batch_size",
+            2: "chunk_attn",
+        }
+
         output_names = ["logits"]
 
-        for i in range(self.num_layers):
+        is_chunked_attention = torch.tensor(
+            [bool((i + 1) % 4) for i in range(self.model.config.num_hidden_layers)], dtype=torch.bool
+        )
+        global_cache_shape = [1, 8, seq_len, 128]
+        chunked_cache_shape = [1, 8, seq_len, 128]
+
+        for i in range(self.model.config.num_hidden_layers):
             for kv in ["key", "value"]:
-                example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
-                dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
+                cache_shape = global_cache_shape if not is_chunked_attention[i] else chunked_cache_shape
+                apply_dynamic_axes = pkv_dynamic_axes if not is_chunked_attention[i] else pkv_dynamic_sliding_axes
+                example_inputs["past_key_values"][i].append(torch.zeros(cache_shape, dtype=torch.float32))
+                dynamic_axes[f"past_{kv}.{i}"] = apply_dynamic_axes
                 output_names.append(f"past_{kv}.{i}_RetainedState")
 
         if self.continuous_batching:
             example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
             dynamic_axes["batch_index"] = {0: "batch_size"}
@@ -1497,6 +1521,7 @@ def build_prefill_specialization(
             "batch_size": 1 if self.continuous_batching else batch_size,
             "seq_len": prefill_seq_len,
             "ctx_len": ctx_len,
+            "chunk_attn": self.model.config.attention_chunk_size,
             "num_logits_to_keep": 1 if self.is_tlm else None,
         }
         if self.continuous_batching:
@@ -1522,6 +1547,7 @@ def build_decode_specialization(
             "batch_size": full_batch_size if self.continuous_batching else batch_size,
             "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1,
             "ctx_len": ctx_len,
+            "chunk_attn": self.model.config.attention_chunk_size,
             "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None,
         }
         if self.continuous_batching:
diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py
index f6aa3296d..08efb58ac 100755
--- a/QEfficient/utils/__init__.py
+++ b/QEfficient/utils/__init__.py
@@ -16,6 +16,7 @@
     get_onnx_dir_name,
     get_padding_shape_from_config,
     get_qpc_dir_path,
+    get_sliding_window_shapes,
     hf_download,
     load_hf_processor,
     load_hf_tokenizer,
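Note: in the default Llama4 text config every fourth decoder layer is a global NoPE layer and the remaining layers use chunked RoPE attention, which is why the export path above keys the example KV shape and dynamic axes off (i + 1) % 4. A small self-contained sketch of that per-layer selection, with hypothetical sizes and with chunked layers capped at the chunk size, as the get_sliding_window_shapes helper below does:

import torch

num_hidden_layers = 8                    # hypothetical
ctx_len, chunk_attn = 8192, 2048         # hypothetical; chunk_attn plays the role of attention_chunk_size
batch_size, num_kv_heads, head_dim = 1, 8, 128

pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
pkv_dynamic_sliding_axes = {0: "batch_size", 2: "chunk_attn"}

example_past, dynamic_axes = [], {}
for i in range(num_hidden_layers):
    is_chunked = bool((i + 1) % 4)       # True for RoPE/chunked layers
    kv_len = min(ctx_len, chunk_attn) if is_chunked else ctx_len
    shape = (batch_size, num_kv_heads, kv_len, head_dim)
    example_past.append((torch.zeros(shape), torch.zeros(shape)))
    for kv in ["key", "value"]:
        dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_sliding_axes if is_chunked else pkv_dynamic_axes

print(example_past[0][0].shape, dynamic_axes["past_key.0"])  # chunked layer
print(example_past[3][0].shape, dynamic_axes["past_key.3"])  # global layer
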
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index ea09e97d7..d2aae4f9c 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -282,6 +282,64 @@ def padding_check_and_fix(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokeni
         tokenizer.pad_token_id = tokenizer.vocab_size - 1
 
 
+def get_sliding_window_shapes(config, batch_size, seq_len):
+    """
+    Gets padding dims from model config - number of kv heads and d_head -
+    and returns the two padding shapes required for initialization of past_key_values:
+    a global shape (batch_size, n_kv_heads, seq_len, d_head) for NoPE layers and a
+    chunked shape (batch_size, n_kv_heads, min(seq_len, attention_chunk_size), d_head)
+    for RoPE layers.
+    --------
+
+    :config: AutoConfig from pretrained model.
+    :batch_size: int. number of input prompts used to create inputs
+    :seq_len: int. sequence length to run the model for.
+
+    Return:
+        Tuple[List[int], List[int]] - (global_cache_shape, chunked_cache_shape)
+    """
+
+    if hasattr(config, "n_head"):  # Assuming n_head is a key in the config (GPTs/CodeGen)
+        n_heads = config.n_head
+        d_head = config.n_embd // config.n_head
+    elif hasattr(config, "num_key_value_heads") and hasattr(
+        config, "num_attention_heads"
+    ):  # Check for num_key_value_heads (Llama/Mistral)
+        n_heads = config.num_key_value_heads
+
+        if hasattr(config, "head_dim"):
+            d_head = config.head_dim
+        else:
+            d_head = config.hidden_size // config.num_attention_heads
+
+    elif hasattr(config, "n_heads"):  # Check for n_heads and d_model in the config (MPT Model)
+        n_heads = config.n_heads
+        d_head = config.d_model // config.n_heads
+    elif hasattr(config, "new_decoder_architecture"):  # Check for Falcon
+        new_decoder_architecture = getattr(config, "new_decoder_architecture")
+        if new_decoder_architecture:  # multi_query is ignored when new_decoder_architecture is True
+            n_heads = config.num_attention_heads
+        else:
+            if hasattr(config, "multi_query"):
+                multi_query_value = getattr(config, "multi_query")
+                if multi_query_value:
+                    n_heads = 1  # MQA , multi query is true
+                else:
+                    n_heads = config.num_attention_heads
+        d_head = config.hidden_size // config.num_attention_heads
+    else:
+        raise ValueError("Invalid model configuration: n_head/d_heads or num_key_value_heads not found.")
+
+    global_cache_shape = [batch_size, n_heads, seq_len, d_head]
+    chunked_cache_shape = [
+        batch_size,
+        n_heads,
+        seq_len if seq_len < config.attention_chunk_size else config.attention_chunk_size,
+        d_head,
+    ]
+
+    return global_cache_shape, chunked_cache_shape
+
+
 def get_padding_shape_from_config(config, batch_size, seq_len):
     """
     Gets padding dims from model config - number of kv heads and d_head
@@ -327,11 +385,29 @@
             d_head = config.hidden_size // config.num_attention_heads
     else:
         raise ValueError("Invalid model configuration: n_head/d_heads or num_key_value_heads not found.")
-    padding_shape = [batch_size, n_heads, seq_len, d_head]
-    if hasattr(config, "architectures") and config.architectures is not None:  # Check for Starcoder1 - 3D layout
-        if "GPTBigCodeForCausalLM" in config.architectures:
-            padding_shape = [batch_size, seq_len, d_head]
-    return padding_shape
+
+    is_chunked_attention = torch.tensor([bool((i + 1) % 4) for i in range(config.num_hidden_layers)], dtype=torch.bool)
+    global_cache_shape = [batch_size, n_heads, seq_len, d_head]
+    chunked_cache_shape = [
+        batch_size,
+        n_heads,
+        seq_len if seq_len < config.attention_chunk_size else config.attention_chunk_size,
+        d_head,
+    ]
+
+    past_key_values = []
+    for i in range(config.num_hidden_layers):
+        cache_shape = global_cache_shape if not is_chunked_attention[i] else chunked_cache_shape
+        new_layer_key_cache = torch.zeros(cache_shape, dtype=torch.float32)
+        new_layer_value_cache = torch.zeros(cache_shape, dtype=torch.float32)
+        pkv = (new_layer_key_cache, new_layer_value_cache)
+        past_key_values.append(pkv)
+    return past_key_values
 
 
 def get_num_layers_from_config(config):
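Note: get_sliding_window_shapes returns a pair of KV-cache shapes, one for global (NoPE) layers and one for chunked (RoPE) layers whose sequence dimension is capped at attention_chunk_size. A minimal illustration of that shape arithmetic with a stand-in config object (attribute values are hypothetical; the real helper also handles GPT/MPT/Falcon-style attribute names):

from types import SimpleNamespace

config = SimpleNamespace(
    num_key_value_heads=8,
    num_attention_heads=40,
    hidden_size=5120,
    head_dim=128,
    attention_chunk_size=8192,
)
batch_size, seq_len = 1, 16384

n_heads = config.num_key_value_heads
d_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)

global_cache_shape = [batch_size, n_heads, seq_len, d_head]
chunked_cache_shape = [batch_size, n_heads, min(seq_len, config.attention_chunk_size), d_head]

print(global_cache_shape)   # [1, 8, 16384, 128]
print(chunked_cache_shape)  # [1, 8, 8192, 128]
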
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index 171517045..ad3161412 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -56,12 +56,6 @@ def get_models_dir():
 
 QEFF_MODELS_DIR = get_models_dir()
 
-ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1
-ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32
-ONNX_EXPORT_EXAMPLE_FBS = 4
-ONNX_EXPORT_EXAMPLE_NLK = 2  # Number of Logits to Keep
-ONNX_EXPORT_OPSET = 13
-
 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"]
 
 # InternVL constants
@@ -88,6 +82,7 @@ def get_models_dir():
 
 # Llama4 Constants
 LLAMA4_NUM_PATCHES = 17
+LLAMA4_ATTENTION_CHUNK_SIZE = 8192
 
 
 class Constants:
diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py
index c45cfec41..6c1059082 100644
--- a/QEfficient/utils/generate_inputs.py
+++ b/QEfficient/utils/generate_inputs.py
@@ -8,7 +8,12 @@
 import numpy as np
 import torch
 
-from QEfficient.utils import get_num_layers_from_config, get_padding_shape_from_config, padding_check_and_fix
+from QEfficient.utils import (
+    get_num_layers_from_config,
+    get_padding_shape_from_config,
+    get_sliding_window_shapes,
+    padding_check_and_fix,
+)
 
 
 class InputHandler:
@@ -33,7 +38,16 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f
         self.ctx_len = ctx_len
         self.full_batch_size = full_batch_size
         self.n_layer = get_num_layers_from_config(config)
-        self.padding_shape = get_padding_shape_from_config(
+        self.past_key_values = get_padding_shape_from_config(
+            config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len
+        )
+        self.is_chunked_attention = torch.tensor(
+            [bool((i + 1) % 4) for i in range(config.num_hidden_layers)], dtype=torch.bool
+        )
+        self.global_shape, self.sliding_shape = get_sliding_window_shapes(
             config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len
         )
 
@@ -76,13 +90,14 @@ def prepare_pytorch_inputs(self):
             inputs["position_ids"] = torch.arange(input_len).view(1, input_len)
             inputs["batch_index"] = torch.arange(1).view(-1, 1)
 
-        past_key_values = []
-        for i in range(self.n_layer):
-            past_key = torch.zeros((self.padding_shape), dtype=torch.float32)
-            past_value = torch.zeros((self.padding_shape), dtype=torch.float32)
-            pkv = (past_key, past_value)
-            past_key_values.append(pkv)
-        inputs["past_key_values"] = tuple(past_key_values)
+        inputs["past_key_values"] = tuple(self.past_key_values)
 
         return inputs
 
@@ -148,9 +163,16 @@ def prepare_ort_inputs(self):
                 axis=1,
             ).astype(np.int64)
 
         for i in range(self.n_layer):
-            inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
-            inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
+            cache_shape = self.global_shape if not self.is_chunked_attention[i] else self.sliding_shape
+            inputs["past_key." + str(i)] = np.zeros(cache_shape, dtype=np.float32)
+            inputs["past_value." + str(i)] = np.zeros(cache_shape, dtype=np.float32)
 
         return inputs
 
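Note: with the two shape families above, prepare_ort_inputs builds the ONNX Runtime KV inputs per layer, picking the sliding shape for chunked layers and the global shape otherwise. A standalone sketch of that loop (layer count and shapes are hypothetical):

import numpy as np

n_layer = 8                            # hypothetical
global_shape = (1, 8, 4096, 128)       # hypothetical global-layer KV shape
sliding_shape = (1, 8, 2048, 128)      # hypothetical chunked-layer KV shape
is_chunked_attention = [bool((i + 1) % 4) for i in range(n_layer)]

inputs = {}
for i in range(n_layer):
    cache_shape = sliding_shape if is_chunked_attention[i] else global_shape
    inputs[f"past_key.{i}"] = np.zeros(cache_shape, dtype=np.float32)
    inputs[f"past_value.{i}"] = np.zeros(cache_shape, dtype=np.float32)

print(inputs["past_key.0"].shape)  # (1, 8, 2048, 128) -> chunked layer
print(inputs["past_key.3"].shape)  # (1, 8, 4096, 128) -> global layer
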
diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py
index 267b2bb9e..02d3d45ef 100644
--- a/QEfficient/utils/run_utils.py
+++ b/QEfficient/utils/run_utils.py
@@ -101,22 +101,30 @@ def run_hf_model_on_pytorch(self, model_hf):
         Return:
             :numpy.ndarray: Generated output tokens
         """
-        input_ids = self.input_handler.tokenizer.encode(self.input_handler.prompt[0], return_tensors="pt")
-
-        input_ids_len = len(input_ids[0])
-
-        for _ in range(self.gen_len):
-            outputs = model_hf(input_ids)
-            logits = outputs.logits[:, -1, :]
-            predicted_token_id = torch.argmax(logits, dim=-1)
-            input_ids = torch.cat([input_ids, predicted_token_id.unsqueeze(1)], dim=-1)
-
-        generated_ids = input_ids[0][input_ids_len:].detach().numpy()
-        generated_text = self.input_handler.tokenizer.decode(generated_ids, skip_special_tokens=True)
-        print("Original HF Model Outputs (Torch CPU): \n")
-        print("Prompt:", repr(self.input_handler.prompt))
-        print("Completion:", repr(generated_text))
-        return generated_ids
+        model_inputs = self.input_handler.tokenizer(self.input_handler.prompt[0], return_tensors="pt")
+
+        input_len = model_inputs["input_ids"].shape[-1]
+
+        with torch.inference_mode():
+            generation = model_hf.generate(**model_inputs, max_new_tokens=self.gen_len, do_sample=False)
+        generation = generation[0][input_len:]
+
+        generated_ids = generation.detach().numpy()
+        generated_text = self.input_handler.tokenizer.decode(generated_ids, skip_special_tokens=True)
+        print("Original HF Model Outputs (Torch CPU): \n")
+        print("Prompt:", repr(self.input_handler.prompt))
+        print("Completion:", repr(generated_text))
+        return generated_ids
 
     def run_kv_model_on_pytorch(self, model):
         """
diff --git a/pyproject.toml b/pyproject.toml
index 99b7c3018..a4cc514d1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ classifiers = [
 ]
 requires-python = ">=3.8,<3.11"
 dependencies = [
-    "transformers==4.51.0",
+    "transformers==4.51.3",
     "huggingface-hub==0.30.0",
     "hf_transfer==0.1.9",
     "peft==0.13.2",
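Note: the run_hf_model_on_pytorch change above switches the CPU reference from a manual argmax loop to transformers' greedy generate(). The same pattern in isolation (the checkpoint name is a placeholder):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "<any-causal-lm-checkpoint>"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_hf = AutoModelForCausalLM.from_pretrained(model_name)

model_inputs = tokenizer("My favourite condiment is", return_tensors="pt")
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model_hf.generate(**model_inputs, max_new_tokens=12, do_sample=False)

# Keep only the newly generated tokens, mirroring the slicing in run_utils.py.
generated_ids = generation[0][input_len:]
print(repr(tokenizer.decode(generated_ids, skip_special_tokens=True)))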