@@ -1126,14 +1126,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
11261126 self .config = config
11271127 self .multimodal_config = multimodal_config
11281128 self .use_data_parallel = multimodal_config .mm_encoder_tp_mode == "data"
1129-
1130- self .visual = Qwen3_VisionTransformer (
1131- config .vision_config ,
1132- norm_eps = getattr (config , "rms_norm_eps" , 1e-6 ),
1133- quant_config = quant_config ,
1134- prefix = maybe_prefix (prefix , "visual" ),
1135- use_data_parallel = self .use_data_parallel ,
1136- )
1129+ if not multimodal_config .get_limit_per_prompt ("image" ) and \
1130+ not multimodal_config .get_limit_per_prompt ("video" ):
1131+ self .visual = None
1132+ else :
1133+ self .visual = Qwen3_VisionTransformer (
1134+ config .vision_config ,
1135+ norm_eps = getattr (config , "rms_norm_eps" , 1e-6 ),
1136+ quant_config = quant_config ,
1137+ prefix = maybe_prefix (prefix , "visual" ),
1138+ use_data_parallel = self .use_data_parallel ,
1139+ )
11371140
11381141 self .language_model = Qwen3LLMForCausalLM (vllm_config = vllm_config ,
11391142 prefix = maybe_prefix (
@@ -1149,11 +1152,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
11491152 config .vision_config .deepstack_visual_indexes
11501153 ) if self .use_deepstack else 0
11511154 # register buffer for deepstack
1152- self .deepstack_input_embeds = [
1153- torch .zeros (vllm_config .scheduler_config .max_num_batched_tokens ,
1154- config .text_config .hidden_size )
1155- for _ in range (self .deepstack_num_level )
1156- ] if self .use_deepstack else None
1155+ if self .use_deepstack and self .visual is not None :
1156+ self .deepstack_input_embeds = [
1157+ torch .zeros (
1158+ vllm_config .scheduler_config .max_num_batched_tokens ,
1159+ config .text_config .hidden_size )
1160+ for _ in range (self .deepstack_num_level )
1161+ ]
1162+ else :
1163+ self .deepstack_input_embeds = None
11571164 self .visual_dim = config .vision_config .out_hidden_size
11581165 self .multiscale_dim = self .visual_dim * self .deepstack_num_level
11591166
@@ -1588,7 +1595,11 @@ def compute_logits(
15881595
15891596 def load_weights (self , weights : Iterable [tuple [str ,
15901597 torch .Tensor ]]) -> set [str ]:
1591- loader = AutoWeightsLoader (self )
1598+
1599+ skip_prefixes = []
1600+ if self .visual is None :
1601+ skip_prefixes .extend (["visual." ])
1602+ loader = AutoWeightsLoader (self , skip_prefixes = skip_prefixes )
15921603 return loader .load_weights (weights , mapper = self .hf_to_vllm_mapper )
15931604
15941605 def get_mm_mapping (self ) -> MultiModelKeys :
0 commit comments