Skip to content

Commit 2b17ebd

Browse files
authored
Added support of 2qpcs for internvl and llava (#279)
Signed-off-by: Mohit Soni <[email protected]>
1 parent f8b5db6 commit 2b17ebd

File tree

5 files changed

+261
-55
lines changed

5 files changed

+261
-55
lines changed

QEfficient/transformers/models/internvl/modeling_internvl.py

Lines changed: 122 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,57 @@
1414
from QEfficient.utils.logging_utils import logger
1515

1616

17+
class QEffInternEncoderWrapper(nn.Module):
    """Expose only the vision path of a wrapped model as a standalone module.

    The wrapped ``model`` must provide an ``extract_feature(pixel_values)``
    method; ``forward`` returns that method's result unchanged.
    """

    def __init__(self, model):
        super().__init__()
        # Keep a handle on the full model; only ``extract_feature`` is used here.
        self.model = model

    def forward(self, pixel_values):
        """Return ``self.model.extract_feature(pixel_values)``."""
        return self.model.extract_feature(pixel_values)
25+
26+
27+
class QEffInternDecoderWrapper(nn.Module):
    """Expose the language-decoder path of a wrapped model as a standalone module.

    ``forward`` embeds ``input_ids`` with the language model's input embedding
    layer, overwrites the embeddings at image-context token positions with rows
    taken from ``vit_embeds``, and runs the language model on the result.

    Args:
        model: Object exposing ``language_model`` (with ``config``,
            ``get_input_embeddings()`` and a call signature accepting
            ``inputs_embeds``, ``position_ids``, ``past_key_values`` and
            ``use_cache``).
        img_context_token: Token id that marks an image placeholder position
            in ``input_ids``. Defaults to ``DEFAULT_IMG_CONTEXT_TOKEN``.
    """

    # TODO: Check if hardcoding this default is okay, i.e. check if this value
    # is common for all intern models.
    DEFAULT_IMG_CONTEXT_TOKEN = 151667

    def __init__(self, model, img_context_token=DEFAULT_IMG_CONTEXT_TOKEN):
        super().__init__()
        self.model = model
        self.config = self.model.language_model.config
        self.img_context_token = img_context_token

    def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
        """Run the language model with image features merged into the embeddings.

        Returns:
            Tuple ``(logits, vit_embeds, past_key_values)``; ``vit_embeds`` is
            the input passed through unchanged.
        """
        input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
        B, N, C = input_embeds.shape

        # Flattened mask of image placeholder positions, in row-major
        # (batch, then sequence) order.
        selected = input_ids.reshape(B * N) == self.img_context_token

        # Running count of placeholders maps the k-th placeholder to the k-th
        # row of the flattened image features. Positions before the first
        # placeholder get index -1; they are masked out by the where() below.
        indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]

        # Restore the (B, N, ...) layout before merging so the mask, the
        # gathered features and the token embeddings line up for any batch
        # size, not only B == 1.
        image_input_embeds = torch.where(
            selected.reshape(B, N, 1),
            image_features_expanded.reshape(B, N, C),
            input_embeds,
        )

        # A single-token step uses the plain token embeddings. The condition is
        # kept as a tensor comparison exactly as in the original code.
        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)

        outputs = self.model.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, vit_embeds, outputs.past_key_values
51+
52+
1753
class QEffInternVLModel(nn.Module):
54+
def get_qeff_vision_encoder(self):
    """Return a ``QEffInternEncoderWrapper`` built around this model."""
    vision_encoder = QEffInternEncoderWrapper(self)
    return vision_encoder
56+
57+
def get_qeff_language_decoder(self):
    """Return a ``QEffInternDecoderWrapper`` built around this model."""
    language_decoder = QEffInternDecoderWrapper(self)
    return language_decoder
59+
1860
def get_specializations(
19-
self, batch_size: int, prefill_seq_len: int, ctx_len: int, img_size: int, **compiler_options
61+
self,
62+
batch_size: int,
63+
prefill_seq_len: int,
64+
ctx_len: int,
65+
img_size: int,
66+
kv_offload: bool = False,
67+
**compiler_options,
2068
):
2169
# TODO: check if this should be named num_patches or something else
2270
num_patches = compiler_options.pop("num_patches", None)
@@ -33,8 +81,18 @@ def get_specializations(
3381
elif img_size is None:
3482
img_size = 448
3583
logger.warning("Setting img_size to be 448, as it was neither passed nor found in vision_config")
36-
37-
specializations = [
84+
if img_size != 448 and kv_offload:
85+
raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")
86+
vision = [
87+
{
88+
"batch_size": batch_size,
89+
"num_patches": num_patches,
90+
"img_size": img_size,
91+
"seq_len": prefill_seq_len,
92+
"ctx_len": ctx_len,
93+
}
94+
]
95+
lang = [
3896
{
3997
"batch_size": batch_size,
4098
"seq_len": prefill_seq_len,
@@ -50,61 +108,92 @@ def get_specializations(
50108
"img_size": img_size,
51109
},
52110
]
53-
return specializations, compiler_options
54111

55-
def get_onnx_dynamic_axes(
56-
self,
57-
):
112+
specializations = {}
113+
114+
if kv_offload:
115+
specializations["vision"] = vision
116+
specializations["lang"] = lang
117+
return specializations, compiler_options
118+
else:
119+
return lang, compiler_options
120+
121+
def get_onnx_dynamic_axes(self, kv_offload: bool = False):
    """Describe which input axes are symbolic.

    Each entry maps an input name to ``{axis_index: symbolic_name}``.

    Args:
        kv_offload: When true, return the vision and language entries
            separately under the ``"vision"`` and ``"lang"`` keys. Otherwise
            return a single flat mapping with the vision entries first.

    Returns:
        The nested or flat mapping described above.
    """
    vision_axes = {"pixel_values": {0: "num_patches", 2: "img_size", 3: "img_size"}}
    lang_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "position_ids": {0: "batch_size", 1: "seq_len"},
    }

    # Every past key/value entry refers to the same axis-description dict.
    kv_axes = {0: "batch_size", 2: "ctx_len"}
    layer_count = self.language_model.config.num_hidden_layers
    lang_axes.update(
        (f"past_{kv}.{layer}", kv_axes) for layer in range(layer_count) for kv in ("key", "value")
    )

    if kv_offload:
        return {"vision": vision_axes, "lang": lang_axes}
    return {**vision_axes, **lang_axes}
70141

71-
def get_output_names(
72-
self,
73-
):
74-
output_names = ["logits", "pixel_values_RetainedState"]
142+
def get_output_names(self, kv_offload: bool = False):
    """List the output names, optionally split per sub-model.

    Args:
        kv_offload: When true, return a dict with separate ``"vision"`` and
            ``"lang"`` name lists. Otherwise return one flat list.

    Returns:
        With ``kv_offload`` false: ``["logits", "pixel_values_RetainedState",
        <past key/value retained-state names>]``. With ``kv_offload`` true:
        ``{"vision": ["vit_embeds"], "lang": ["logits",
        "vit_embeds_RetainedState", <past key/value retained-state names>]}``.
    """
    layer_count = self.language_model.config.num_hidden_layers
    retained_kv = [
        f"past_{kv}.{layer}_RetainedState" for layer in range(layer_count) for kv in ("key", "value")
    ]

    if not kv_offload:
        return ["logits", "pixel_values_RetainedState", *retained_kv]

    return {
        "vision": ["vit_embeds"],
        "lang": ["logits", "vit_embeds_RetainedState", *retained_kv],
    }
79158

80159
def get_dummy_inputs(self, kv_offload: bool = False):
81-
if kv_offload:
82-
raise ValueError("kv_offload method not supported for InternVL yet!")
83160
num_patches = 13
84161
C = 3
85162
if vis_cfg := getattr(self.config, "vision_config", None):
86163
img_size = getattr(vis_cfg, "image_size", 448)
87164
else:
88165
img_size = 448
166+
if img_size != 448 and kv_offload:
167+
raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")
168+
169+
# Taken from the modeling files of OpenGVLab/InternVL2_5-1B
170+
feature_size = int((((self.config.vision_config.hidden_size**0.5) * self.config.downsample_ratio) ** 2))
89171

90172
# Define shapes
91173
inputs_shapes = {}
92174
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
175+
inputs_shapes["vit_embeds"] = (
176+
num_patches,
177+
feature_size,
178+
self.language_model.config.hidden_size,
179+
)
93180
inputs_shapes["position_ids"] = (
94181
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
95182
constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
96183
)
97184
inputs_shapes["pixel_values"] = (num_patches, C, img_size, img_size)
98185

99186
# Define inputs
100-
inputs = {}
101-
inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
102-
inputs["position_ids"] = (
187+
vision_inputs = {}
188+
lang_inputs = {}
189+
vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
190+
lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
191+
lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
192+
lang_inputs["position_ids"] = (
103193
torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
104194
.view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
105195
.repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
106196
)
107-
inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
108197

109198
# Add data for KV
110199
kv_cache_shape = get_padding_shape_from_config(
@@ -113,10 +202,18 @@ def get_dummy_inputs(self, kv_offload: bool = False):
113202
seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
114203
)
115204

116-
inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
205+
lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
117206
for i in range(self.language_model.config.num_hidden_layers):
118207
for kv in ["key", "value"]:
119-
inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
208+
lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
209+
210+
inputs = {}
211+
if kv_offload:
212+
inputs["vision"] = vision_inputs
213+
inputs["lang"] = lang_inputs
214+
else:
215+
lang_inputs.pop("vit_embeds")
216+
inputs = {**vision_inputs, **lang_inputs}
120217

121218
return inputs
122219

0 commit comments

Comments
 (0)