Skip to content

Commit

Permalink
Fix mixed batch for multi modal models (#1702)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Oct 17, 2024
1 parent dd3809f commit d17d19e
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
timeout-minutes: 20
run: |
cd test/srt
- python3 run_suite.py --suite minimal --range-begin 5 --range-end 16
+ python3 run_suite.py --suite minimal --range-begin 5 --range-end 17
unit-test-backend-part-3:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
Expand All @@ -96,7 +96,7 @@ jobs:
timeout-minutes: 20
run: |
cd test/srt
- python3 run_suite.py --suite minimal --range-begin 16
+ python3 run_suite.py --suite minimal --range-begin 17
performance-test-1-gpu-part-1:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
Expand Down
5 changes: 1 addition & 4 deletions python/sglang/srt/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@ def forward(
image_sizes = [
image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
]
- image_offsets = [
- image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
- ]

########## Encode Image ########

Expand Down Expand Up @@ -358,7 +355,7 @@ def forward(
prefix_len = prefix_lens_cpu[i]

# Multiple images
- for j, image_offset in enumerate(image_offsets[i]):
+ for j, image_offset in enumerate(image_inputs[i].image_offsets):
if image_offset < prefix_len:
continue

Expand Down
55 changes: 55 additions & 0 deletions test/srt/test_vision_openai_server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
"""
Usage:
python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch
"""

import base64
import io
import json
import os
import unittest
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import openai
Expand Down Expand Up @@ -288,6 +294,55 @@ def test_regex(self):
assert isinstance(js_obj["color"], str)
assert isinstance(js_obj["number_of_cars"], int)

def run_decode_with_image(self, image_id):
    """Send one chat-completion request, optionally with an image attached.

    An ``image_id`` of 0 attaches the example image and 1 attaches the logo
    image; any other value sends a text-only request. The text prompt is
    always appended after the optional image part. The response is checked
    only for an assistant role and a string message body.
    """
    known_images = (
        (
            0,
            "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true",
        ),
        (
            1,
            "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png",
        ),
    )
    # First matching id wins; no match means no image part is sent.
    selected_url = next(
        (url for known_id, url in known_images if image_id == known_id),
        None,
    )

    if selected_url is None:
        parts = []
    else:
        parts = [{"type": "image_url", "image_url": {"url": selected_url}}]
    parts.append(
        {
            "type": "text",
            "text": "Describe this image in a very short sentence.",
        }
    )
    user_message = {"role": "user", "content": parts}

    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    response = client.chat.completions.create(
        model="default",
        messages=[user_message],
        temperature=0,
    )

    reply = response.choices[0].message
    assert reply.role == "assistant"
    assert isinstance(reply.content, str)

def test_mixed_batch(self):
    """Issue image and text-only requests concurrently.

    Cycles through image ids 0, 1, 2 four times (12 requests in total) on a
    pool of 4 worker threads, so requests with and without images are in
    flight at the same time.
    """
    request_ids = [image_id for _ in range(4) for image_id in (0, 1, 2)]
    with ThreadPoolExecutor(max_workers=4) as pool:
        # Drain the result iterator so that an exception raised inside a
        # worker is re-raised here and fails the test.
        for _ in pool.map(self.run_decode_with_image, request_ids):
            pass


if __name__ == "__main__":
unittest.main()

0 comments on commit d17d19e

Please sign in to comment.