@@ -1125,14 +1125,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
         self.config = config
         self.multimodal_config = multimodal_config
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
-
-        self.visual = Qwen3_VisionTransformer(
-            config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "visual"),
-            use_data_parallel=self.use_data_parallel,
-        )
+        if not multimodal_config.get_limit_per_prompt("image") and \
+                not multimodal_config.get_limit_per_prompt("video"):
+            self.visual = None
+        else:
+            self.visual = Qwen3_VisionTransformer(
+                config.vision_config,
+                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "visual"),
+                use_data_parallel=self.use_data_parallel,
+            )
 
         self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
                                                   prefix=maybe_prefix(
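With this change the vision tower is only constructed when at least one multimodal modality is allowed per prompt. A minimal usage sketch of the text-only path, assuming vLLM's `limit_mm_per_prompt` engine argument is what ultimately backs the `get_limit_per_prompt` calls above (the checkpoint name is just an example):

```python
from vllm import LLM

# Sketch, not a verified invocation: zeroing every modality's per-prompt
# limit should make get_limit_per_prompt("image") / ("video") both falsy,
# so the constructor above sets self.visual = None and never builds
# Qwen3_VisionTransformer.
llm = LLM(
    model="Qwen/Qwen3-VL-30B-A3B-Instruct",  # example checkpoint
    limit_mm_per_prompt={"image": 0, "video": 0},
)
```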
@@ -1148,11 +1151,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
             config.vision_config.deepstack_visual_indexes
         ) if self.use_deepstack else 0
         # register buffer for deepstack
-        self.deepstack_input_embeds = [
-            torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
-                        config.text_config.hidden_size)
-            for _ in range(self.deepstack_num_level)
-        ] if self.use_deepstack else None
+        if self.use_deepstack and self.visual is not None:
+            self.deepstack_input_embeds = [
+                torch.zeros(
+                    vllm_config.scheduler_config.max_num_batched_tokens,
+                    config.text_config.hidden_size)
+                for _ in range(self.deepstack_num_level)
+            ]
+        else:
+            self.deepstack_input_embeds = None
         self.visual_dim = config.vision_config.out_hidden_size
         self.multiscale_dim = self.visual_dim * self.deepstack_num_level
 
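Skipping these buffers matters because each deepstack level allocates a `max_num_batched_tokens × hidden_size` tensor up front. A back-of-envelope estimate under assumed config values (illustrative numbers, not actual Qwen3-VL defaults):

```python
# Rough estimate of the memory avoided when self.visual is None.
# All sizes below are assumptions for illustration.
max_num_batched_tokens = 8192  # scheduler_config.max_num_batched_tokens
hidden_size = 4096             # text_config.hidden_size
num_levels = 3                 # len(deepstack_visual_indexes)
bytes_per_elem = 4             # assuming torch's default float32 dtype

total = max_num_batched_tokens * hidden_size * bytes_per_elem * num_levels
print(f"{total / 2**20:.0f} MiB")  # -> 384 MiB never allocated
```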
@@ -1526,7 +1533,11 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+
+        skip_prefixes = []
+        if self.visual is None:
+            skip_prefixes.extend(["visual."])
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
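For reviewers unfamiliar with `skip_prefixes`: conceptually it drops checkpoint tensors by name prefix so the loader never tries to place `visual.*` weights into a module that was not built. The sketch below is a simplified, hypothetical stand-in for that filtering step, not vLLM's actual `AutoWeightsLoader` implementation (which also handles name mapping and sharding):

```python
from collections.abc import Iterable, Iterator

import torch


def filter_skipped(
    weights: Iterable[tuple[str, torch.Tensor]],
    skip_prefixes: list[str],
) -> Iterator[tuple[str, torch.Tensor]]:
    # Drop tensors whose names start with a skipped prefix, e.g. every
    # "visual.*" weight when the vision tower was never constructed.
    for name, tensor in weights:
        if not any(name.startswith(p) for p in skip_prefixes):
            yield name, tensor


weights = [
    ("visual.blocks.0.attn.qkv.weight", torch.empty(1)),
    ("language_model.model.embed_tokens.weight", torch.empty(1)),
]
kept = dict(filter_skipped(weights, ["visual."]))
assert "visual.blocks.0.attn.qkv.weight" not in kept
```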