from QEfficient.utils.logging_utils import logger


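+# Wrapper that exposes only InternVL's vision feature extraction, so the image encoder can be exported on its own.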
+class QEffInternEncoderWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, pixel_values):
+        vit_embeds = self.model.extract_feature(pixel_values)
+        return vit_embeds
+
+
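+# Wrapper around the language model that splices pre-computed vision embeddings (vit_embeds) into the text
+# embeddings at the image-placeholder positions before decoding.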
+class QEffInternDecoderWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.config = self.model.language_model.config
+
+    def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
+        # TODO: Check if Hardcoding this is okay, i.e. check if this value is common for all intern models
+        IMG_CONTEXT_TOKEN = 151667
+
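+        # Scatter step: every position whose token id equals IMG_CONTEXT_TOKEN is overwritten with the matching
+        # row of vit_embeds; the cumulative sum over the boolean mask maps the k-th placeholder to the k-th embedding.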
+        input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
+        B, N, C = input_embeds.shape
+        image_input_embeds = input_embeds.reshape(B * N, C)
+        image_input_ids = input_ids.reshape(B * N)
+        selected = image_input_ids == IMG_CONTEXT_TOKEN
+        indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
+        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
+        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
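+        # During decode (a single-token input) there are no image placeholders, so the plain text embeddings are kept.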
+        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
+        outputs = self.model.language_model(
+            inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+        )
+        return outputs.logits, vit_embeds, outputs.past_key_values
+
+
class QEffInternVLModel(nn.Module):
+    def get_qeff_vision_encoder(self):
+        return QEffInternEncoderWrapper(self)
+
+    def get_qeff_language_decoder(self):
+        return QEffInternDecoderWrapper(self)
+
    def get_specializations(
-        self, batch_size: int, prefill_seq_len: int, ctx_len: int, img_size: int, **compiler_options
+        self,
+        batch_size: int,
+        prefill_seq_len: int,
+        ctx_len: int,
+        img_size: int,
+        kv_offload: bool = False,
+        **compiler_options,
    ):
        # TODO: check if this should be named num_patches or something else
        num_patches = compiler_options.pop("num_patches", None)
@@ -33,8 +81,18 @@ def get_specializations(
        elif img_size is None:
            img_size = 448
            logger.warning("Setting img_size to be 448, as it was neither passed nor found in vision_config")
-
-        specializations = [
+        if img_size != 448 and kv_offload:
+            raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")
+        vision = [
+            {
+                "batch_size": batch_size,
+                "num_patches": num_patches,
+                "img_size": img_size,
+                "seq_len": prefill_seq_len,
+                "ctx_len": ctx_len,
+            }
+        ]
+        lang = [
            {
                "batch_size": batch_size,
                "seq_len": prefill_seq_len,
@@ -50,61 +108,92 @@ def get_specializations(
                "img_size": img_size,
            },
        ]
-        return specializations, compiler_options

-    def get_onnx_dynamic_axes(
-        self,
-    ):
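+        # With kv_offload the vision and language specializations are returned separately (keyed "vision"/"lang");
+        # otherwise only the language specializations are returned.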
+        specializations = {}
+
+        if kv_offload:
+            specializations["vision"] = vision
+            specializations["lang"] = lang
+            return specializations, compiler_options
+        else:
+            return lang, compiler_options
+
+    def get_onnx_dynamic_axes(self, kv_offload: bool = False):
        # Define dynamic axes
-        dynamic_axes = {}
-        dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
-        dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
+        vision_dynamic_axes = {}
+        lang_dynamic_axes = {}
+        lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
+        lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
+        vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}

        pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
        for i in range(self.language_model.config.num_hidden_layers):
            for kv in ["key", "value"]:
-                dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
+                lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes

+        dynamic_axes = {}
+        if kv_offload:
+            dynamic_axes["vision"] = vision_dynamic_axes
+            dynamic_axes["lang"] = lang_dynamic_axes
+        else:
+            dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes}
        return dynamic_axes

-    def get_output_names(
-        self,
-    ):
-        output_names = ["logits", "pixel_values_RetainedState"]
+    def get_output_names(self, kv_offload: bool = False):
+        vision_output_names = ["vit_embeds"]
+        lang_output_names = ["logits"]
        for i in range(self.language_model.config.num_hidden_layers):
            for kv in ["key", "value"]:
-                output_names.append(f"past_{kv}.{i}_RetainedState")
+                lang_output_names.append(f"past_{kv}.{i}_RetainedState")
+
+        output_names = {}
+        if kv_offload:
+            lang_output_names.insert(1, "vit_embeds_RetainedState")
+            output_names["vision"] = vision_output_names
+            output_names["lang"] = lang_output_names
+        else:
+            lang_output_names.insert(1, "pixel_values_RetainedState")
+            return lang_output_names
        return output_names

    def get_dummy_inputs(self, kv_offload: bool = False):
-        if kv_offload:
-            raise ValueError("kv_offload method not supported for InternVL yet!")
        num_patches = 13
        C = 3
        if vis_cfg := getattr(self.config, "vision_config", None):
            img_size = getattr(vis_cfg, "image_size", 448)
        else:
            img_size = 448
+        if img_size != 448 and kv_offload:
+            raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")
+
+        # Taken from the modeling files of OpenGVLab/InternVL2_5-1B
+        feature_size = int((((self.config.vision_config.hidden_size**0.5) * self.config.downsample_ratio) ** 2))
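+        # For example, with InternVL2_5-1B's config (vision hidden_size 1024 and downsample_ratio 0.5, assumed here):
+        # 1024 ** 0.5 = 32, 32 * 0.5 = 16, 16 ** 2 = 256 vision tokens per image tile.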

        # Define shapes
        inputs_shapes = {}
        inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+        inputs_shapes["vit_embeds"] = (
+            num_patches,
+            feature_size,
+            self.language_model.config.hidden_size,
+        )
        inputs_shapes["position_ids"] = (
            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
            constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
        )
        inputs_shapes["pixel_values"] = (num_patches, C, img_size, img_size)

        # Define inputs
-        inputs = {}
-        inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
-        inputs["position_ids"] = (
+        vision_inputs = {}
+        lang_inputs = {}
+        vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
+        lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
+        lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
+        lang_inputs["position_ids"] = (
            torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
            .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
            .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
        )
-        inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)

        # Add data for KV
        kv_cache_shape = get_padding_shape_from_config(
@@ -113,10 +202,18 @@ def get_dummy_inputs(self, kv_offload: bool = False):
            seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
        )

-        inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
+        lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
        for i in range(self.language_model.config.num_hidden_layers):
            for kv in ["key", "value"]:
-                inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+
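+        # With kv_offload the dummy inputs stay split per sub-model ("vision"/"lang"); otherwise they are merged into
+        # one flat dict and the pre-computed vit_embeds entry is dropped in favour of pixel_values.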
+        inputs = {}
+        if kv_offload:
+            inputs["vision"] = vision_inputs
+            inputs["lang"] = lang_inputs
+        else:
+            lang_inputs.pop("vit_embeds")
+            inputs = {**vision_inputs, **lang_inputs}

        return inputs