Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions vllm/model_executor/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,14 +1125,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
if not multimodal_config.get_limit_per_prompt("image") and \
not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)

self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
prefix=maybe_prefix(
Expand All @@ -1148,11 +1151,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
config.vision_config.deepstack_visual_indexes
) if self.use_deepstack else 0
# register buffer for deepstack
self.deepstack_input_embeds = [
torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
] if self.use_deepstack else None
if self.use_deepstack and self.visual is not None:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level

Expand Down Expand Up @@ -1526,7 +1533,11 @@ def compute_logits(

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)

skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

def get_mm_mapping(self) -> MultiModelKeys:
Expand Down
32 changes: 20 additions & 12 deletions vllm/model_executor/models/qwen3_vl_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
if not multimodal_config.get_limit_per_prompt("image") and \
not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)

self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
prefix=maybe_prefix(
Expand All @@ -341,10 +345,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config.vision_config.deepstack_visual_indexes
) if self.use_deepstack else 0
# register buffer for deepstack
self.deepstack_input_embeds = [
torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
] if self.use_deepstack else None
if self.use_deepstack and self.visual is not None:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level