Skip to content

Commit

Permalink
[Bugfix] qwen2vl forward_extend (#1727)
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhang2077 authored Oct 20, 2024
1 parent b48edff commit 554fbf9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 38 deletions.
59 changes: 25 additions & 34 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional

import numpy as np
import torch

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
Expand Down Expand Up @@ -134,16 +133,23 @@ def compute_mrope_positions(
)
elif self.forward_mode.is_extend():
for i, image_inputs in enumerate(batch.image_inputs):
extend_start_loc, extend_seq_len, extend_prefix_len = (
self.extend_start_loc[i],
self.extend_seq_lens[i],
self.extend_prefix_lens[i],
)
if image_inputs is None:
# text only
mrope_positions = [[i for i in range(self.seq_lens[i])]] * 3
mrope_positions = [
[
pos
for pos in range(
extend_prefix_len, extend_prefix_len + extend_seq_len
)
]
] * 3
mrope_position_delta = 0
else:
extend_start_loc, extend_seq_len, extend_prefix_len = (
self.extend_start_loc[i],
self.extend_seq_lens[i],
self.extend_prefix_lens[i],
)
mrope_positions, mrope_position_delta = (
MRotaryEmbedding.get_input_positions(
input_tokens=self.input_ids[
Expand All @@ -163,12 +169,9 @@ def compute_mrope_positions(
mrope_positions_list[i] = mrope_positions
batch.mrope_positions_delta[i].append(mrope_position_delta)

self.mrope_positions = torch.tensor(
np.concatenate(
[np.array(pos) for pos in mrope_positions_list],
axis=1,
),
device=device,
self.mrope_positions = torch.concat(
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
axis=1,
)
self.mrope_positions = self.mrope_positions.to(torch.int64)

Expand All @@ -177,18 +180,15 @@ def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch):
if self.forward_mode.is_decode():
self.positions = (self.seq_lens - 1).to(torch.int64)
else:
self.positions = torch.tensor(
np.concatenate(
[
np.arange(prefix_len, prefix_len + extend_len)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
),
device=device,
).to(torch.int64)
self.positions = torch.concat(
[
torch.arange(prefix_len, prefix_len + extend_len, device=device)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
)

@classmethod
def init_new(
Expand All @@ -213,15 +213,6 @@ def init_new(

# Init position information
if not ret.forward_mode.is_decode():
ret.positions = torch.concat(
[
torch.arange(prefix_len, prefix_len + extend_len, device=device)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
)
ret.image_inputs = batch.image_inputs
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
Expand Down
4 changes: 0 additions & 4 deletions test/srt/test_vision_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,6 @@ def setUpClass(cls):
)
cls.base_url += "/v1"

def test_mixed_batch(self):
# FIXME: Temporarily skip this test.
pass


if __name__ == "__main__":
unittest.main()

0 comments on commit 554fbf9

Please sign in to comment.