@@ -752,8 +752,8 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
            seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
        )

-        lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.num_hidden_layers)]
-        for i in range(self.model.config.num_hidden_layers):
+        lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)]
+        for i in range(self.model.config.text_config.num_hidden_layers):
            for kv in ["key", "value"]:
                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
@@ -779,10 +779,10 @@ def get_specializations(
        **compiler_options,
    ):
        if height is None or width is None:
-            height = 1365
-            width = 2048
+            height = constants.QWEN2_5_VL_HEIGHT
+            width = constants.QWEN2_5_VL_WIDTH
            logger.warning(
-                "Setting height and width to be 1365 and 2048 respectively, as it was neither passed nor found in vision_config"
+                f"Setting height and width to be {height} and {width} respectively, as it was neither passed nor found in vision_config"
            )
        prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
        ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
@@ -882,7 +882,7 @@ def smart_resize(

    def get_onnx_dynamic_axes(self, kv_offload: bool = False):
        # Define dynamic axes
-        num_layers = self.config.num_hidden_layers
+        num_layers = self.config.text_config.num_hidden_layers

        vision_dynamic_axes = {
            "pixel_values": {0: "grid_height", 1: "grid_width"},
@@ -900,6 +900,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
            lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}

        dynamic_axes = {}
+
        if kv_offload:
            dynamic_axes["vision"] = vision_dynamic_axes
            dynamic_axes["lang"] = lang_dynamic_axes
@@ -911,7 +912,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
    def get_output_names(self, kv_offload: bool = False):
        vision_output_names = ["vision_embeds"]
        lang_output_names = ["logits"]
-        for i in range(self.model.config.num_hidden_layers):
+        for i in range(self.model.config.text_config.num_hidden_layers):
            for kv in ["key", "value"]:
                lang_output_names.append(f"past_{kv}.{i}_RetainedState")

@@ -927,6 +928,32 @@ def get_output_names(self, kv_offload: bool = False):
            return lang_output_names
        return output_names

+    def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size=1):
+        input_ids_length = inputs["input_ids"].shape[1]
+
+        inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1)
+
+        pos_ids, rope_deltas = self.model.get_rope_index(
+            inputs["input_ids"],
+            None if "image_grid_thw" not in inputs else inputs["image_grid_thw"],
+            video_grid_thw=None,
+            second_per_grid_ts=None,
+            attention_mask=inputs["attention_mask"],
+        )
+
+        inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0)
+
+        num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil divide without float
+        padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prefill_seq_len
+
+        inputs["position_ids"] = F.pad(
+            inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1
+        )
+
+        inputs.pop("image_grid_thw", None)
+
+        return inputs
+
    def get_inputs_info(self):
        return [
            IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
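For context while reviewing: this is roughly how the new helper is expected to be called. A minimal sketch, not part of the diff; it assumes a `qeff_model` built and compiled as in examples/qwen2_5_vl_example.py below, and the checkpoint, image URL, and prompt here are placeholders.

import requests
from PIL import Image
from transformers import AutoProcessor

# qeff_model: a compiled QEFFAutoModelForImageTextToText (see the example script below).
model_id = "Qwen/Qwen2.5-VL-3B-Instruct"  # placeholder checkpoint
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)
messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]}]
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = processor(images=image, text=text, return_tensors="pt")

# The helper stacks a plain arange row on top of the three rope rows returned by
# get_rope_index (temporal, height, width), pads every row to a multiple of
# prefill_seq_len with -1, and drops image_grid_thw, so the dict can be passed
# straight to qeff_model.generate().
inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=1)
output = qeff_model.generate(inputs=inputs, generation_len=100)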
4 changes: 4 additions & 0 deletions QEfficient/utils/constants.py
@@ -125,6 +125,10 @@ def get_models_dir():
# Wav2Vec2 Constant
WAV2VEC2_MAX_SEQ_LEN = 480000  # 30 seconds of audio at 16 kHz sampling rate (16,000 samples/sec × 30 sec)

+# Qwen2_5_vl Constants
+QWEN2_5_VL_HEIGHT = 354
+QWEN2_5_VL_WIDTH = 536
+

class Constants:
    # Export Constants.
51 changes: 5 additions & 46 deletions examples/qwen2_5_vl_example.py
@@ -6,8 +6,6 @@
# -----------------------------------------------------------------------------

import requests
- import torch
- import torch.nn.functional as F
import transformers
from PIL import Image
from qwen_vl_utils import process_vision_info
@@ -18,8 +16,7 @@
## For AWQ model update pytorch version to 2.8.*
model_id = "Qwen/Qwen2.5-VL-32B-Instruct"
config = AutoConfig.from_pretrained(model_id)
-
- ## Use complete model without changing num_hidden_layers as it will not work for TF version 4.55.0 for Qwen2.5VL model
+ config.text_config.num_hidden_layers = 2

qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
    model_id, attn_implementation="eager", kv_offload=True, config=config
@@ -28,13 +25,13 @@
processor = AutoProcessor.from_pretrained(model_id)

### use skip_vision=True if you want to run only text, otherwise False ###
- skip_vision = False
+ skip_vision = True

if skip_vision:
    ## Only Text ##

    ## Set Batch_Size ##
-     batch_size = 2
+     batch_size = 1
    qeff_model.compile(
        batch_size=batch_size,
        prefill_seq_len=128,
@@ -68,25 +65,7 @@
        return_tensors="pt",
    )

-     pos_ids, rope_deltas = qeff_model.model.get_rope_index(
-         inputs["input_ids"],
-         image_grid_thw=None,
-         video_grid_thw=None,
-         second_per_grid_ts=None,
-         attention_mask=inputs["attention_mask"],
-     )
-
-     input_ids_length = inputs["input_ids"].shape[1]
-
-     inputs["position_ids"] = torch.cat([pos_ids, pos_ids[0].unsqueeze(0)], dim=0)
-
-     prefill_seq_len = 128
-     num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil divide without float
-     padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
-
-     inputs["position_ids"] = F.pad(
-         inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1
-     )
+     inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size)

    streamer = TextStreamer(tokenizer)
    output = qeff_model.generate(inputs=inputs, generation_len=100)
@@ -148,29 +127,9 @@
        padding=True,
        return_tensors="pt",
    )
-     input_ids_length = inputs["input_ids"].shape[1]
-
-     inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1)
-
-     pos_ids, rope_deltas = qeff_model.model.model.get_rope_index(
-         inputs["input_ids"],
-         inputs["image_grid_thw"],
-         video_grid_thw=None,
-         second_per_grid_ts=None,
-         attention_mask=inputs["attention_mask"],
-     )
-
-     inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0)
-
-     prefill_seq_len = 128
-     num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil divide without float
-     padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
-
-     inputs["position_ids"] = F.pad(
-         inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1
-     )
+     inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size)

    inputs.pop("image_grid_thw", None)
    streamer = TextStreamer(tokenizer)
    output = qeff_model.generate(inputs=inputs, generation_len=100)
    print(output.generated_ids)
15 changes: 15 additions & 0 deletions tests/transformers/models/test_image_text_to_text_models.py
@@ -134,6 +134,17 @@
            "Can you describe the image in detail.",
            1,
        ),
+         (
+             "Qwen/Qwen2.5-VL-3B-Instruct",
+             True,
+             1,
+             128,
+             4096,
+             1540,
+             "https://picsum.photos/id/237/536/354",
+             "Can you describe the image in detail.",
+             1,
+         ),
        # (
        #     "meta-llama/Llama-3.2-11B-Vision-Instruct",
        #     True,
@@ -320,6 +331,10 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
        qnn_config=qnn_config,
    )
    inputs = processor(images=image, text=prompt, return_tensors="pt")
+     if hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl":
+         inputs = qeff_model.model.prepare_inputs_for_generation(
+             inputs=inputs, prefill_seq_len=prompt_len, batch_size=batch_size
+         )
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
    print("QPC Outputs (QAIC):")