-
Notifications
You must be signed in to change notification settings - Fork 101
multimodal model embedding fixes #759
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
libinta
wants to merge
42
commits into
vllm-project:main
Choose a base branch
from
libinta:libinta/remove_gather_scatter
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+184
−89
Open
Changes from 19 commits
Commits
Show all changes
42 commits
Select commit
Hold shift + click to select a range
9d8e272
Pick model runner change related to PR30475.
libinta 49d7633
add qwen3_vl.py functions
libinta bdff63f
Merge branch 'main' into libinta/remove_gather_scatter
libinta c6526de
precomit fix
libinta 7c6329e
precommit fix and fix use_window_sdpa
libinta bff3cf5
Update qwen3_vl.py
iboiko-habana 625d9c2
Update qwen3_vl.py
iboiko-habana 568b4eb
Merge branch 'main' into libinta/remove_gather_scatter
iboiko-habana 495643a
Merge branch 'main' into libinta/remove_gather_scatter
libinta 327a9cc
Merge branch 'main' into libinta/remove_gather_scatter
iboiko-habana bb3ac24
Update qwen3_vl.py
iboiko-habana fe67f98
Merge branch 'main' into libinta/remove_gather_scatter
libinta 8a9efd1
Merge branch 'main' into libinta/remove_gather_scatter
libinta a394b9a
Merge branch 'main' into libinta/remove_gather_scatter
libinta 6502061
Merge branch 'main' into libinta/remove_gather_scatter
iboiko-habana 48a96db
fix test failure
libinta 0171641
Merge branch 'main' into libinta/remove_gather_scatter
iboiko-habana db10548
fix precommit issue
libinta 40d7635
Update interfaces.py for precommit fix
libinta e23e6d2
Update hpu_model_runner.py to match with upstream for MultiModalBudget
libinta 46facad
Merge branch 'main' into libinta/remove_gather_scatter
libinta 4089adf
Update qwen3_vl.py for precommit fix
libinta 79d90a4
Update qwen3_vl.py for precommit fix
libinta e370a49
Merge branch 'main' into libinta/remove_gather_scatter
iboiko-habana 0df1f20
Merge branch 'main' into libinta/remove_gather_scatter
iboiko-habana 5fdf237
Merge branch 'main' into libinta/remove_gather_scatter
libinta 07f40c9
add back warmup with ratio and video warmup
libinta 9db6b78
Update ops.py with removing uncessary change
libinta 9be0056
Update hpu_model_runner.py for precommit fix
libinta 3ff7e80
Merge branch 'main' into libinta/remove_gather_scatter
libinta b4f2e6c
Update hpu_model_runner.py for precommit fix
libinta 02c239b
Update hpu_model_runner.py for precommit fix
libinta 3dd1f5c
Update hpu_model_runner.py for precommit fix
libinta 7757e80
Update hpu_model_runner.py for precommit fix
libinta 9097164
Merge branch 'main' into libinta/remove_gather_scatter
libinta 913176a
fix qwen2.5vl unified attn test failure
libinta 091c5fe
precommit fix
libinta f0613fd
precommit fix
libinta ec827b8
add more mm bucket
libinta 4cf5cb1
precommit fix
libinta 150cf7a
Merge branch 'main' into libinta/remove_gather_scatter
libinta f46b48d
Update qwen2.5-vl-7b.yaml to revert change
libinta File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| import torch | ||
| from .utils import _merge_multimodal_embeddings | ||
| from vllm.model_executor.models.interfaces import MultiModalEmbeddings | ||
| from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration | ||
| from vllm.model_executor.models.interfaces import _require_is_multimodal | ||
|
|
||
|
|
||
| class HpuQwen3_VLForConditionalGeneration(Qwen3VLForConditionalGeneration): | ||
|
|
||
| def _compute_deepstack_embeds( | ||
| self, | ||
| inputs_embeds: torch.Tensor, | ||
| multimodal_embeddings: MultiModalEmbeddings, | ||
| is_multimodal: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, MultiModalEmbeddings]: | ||
| visual_lens = [len(x) for x in multimodal_embeddings] | ||
| multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0) | ||
|
|
||
| ( | ||
| multimodal_embeddings_main, | ||
| multimodal_embeddings_multiscale, | ||
| ) = torch.split( | ||
| multimodal_embeddings_cat, | ||
| [self.visual_dim, self.multiscale_dim], | ||
| dim=-1, | ||
| ) | ||
|
|
||
| multimodal_embeddings = torch.split(multimodal_embeddings_main, visual_lens, dim=0) | ||
| multimodal_embeddings_multiscale = torch.split(multimodal_embeddings_multiscale, visual_lens, dim=0) | ||
|
|
||
| deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0), | ||
| self.deepstack_num_level * inputs_embeds.size(1)) | ||
|
|
||
| deepstack_input_embeds = _merge_multimodal_embeddings( | ||
| inputs_embeds=deepstack_input_embeds, | ||
| multimodal_embeddings=multimodal_embeddings_multiscale, | ||
| is_multimodal=is_multimodal, | ||
| ) | ||
| deepstack_input_embeds = deepstack_input_embeds.view(inputs_embeds.shape[0], self.deepstack_num_level, | ||
| self.visual_dim) | ||
| deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2) | ||
|
|
||
| return deepstack_input_embeds, multimodal_embeddings | ||
|
|
||
| def embed_input_ids( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| multimodal_embeddings: MultiModalEmbeddings | None = None, | ||
| *, | ||
| is_multimodal: torch.Tensor | None = None, | ||
| handle_oov_mm_token: bool = False, | ||
| ) -> torch.Tensor: | ||
| inputs_embeds = self._embed_text_input_ids( | ||
| input_ids, | ||
| self.language_model.embed_input_ids, | ||
| is_multimodal=is_multimodal, | ||
| handle_oov_mm_token=handle_oov_mm_token, | ||
| ) | ||
|
|
||
| if multimodal_embeddings is None or len(multimodal_embeddings) == 0: | ||
| return inputs_embeds | ||
|
|
||
| is_multimodal = _require_is_multimodal(is_multimodal) | ||
|
|
||
| if self.use_deepstack: | ||
| ( | ||
| deepstack_input_embeds, | ||
| multimodal_embeddings, | ||
| ) = self._compute_deepstack_embeds( | ||
| inputs_embeds=inputs_embeds, | ||
| multimodal_embeddings=multimodal_embeddings, | ||
| is_multimodal=is_multimodal, | ||
| ) | ||
| else: | ||
| deepstack_input_embeds = None | ||
|
|
||
| inputs_embeds = _merge_multimodal_embeddings( | ||
| inputs_embeds=inputs_embeds, | ||
| multimodal_embeddings=multimodal_embeddings, | ||
| is_multimodal=is_multimodal, | ||
| ) | ||
|
|
||
| if deepstack_input_embeds is not None: | ||
| self._set_deepstack_input_embeds(deepstack_input_embeds) | ||
|
|
||
| return inputs_embeds |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.