From d17d19e5b84ec459e8fcce238232781a731ca488 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Thu, 17 Oct 2024 10:27:26 -0700
Subject: [PATCH] Fix mixed batch for multi modal models (#1702)

---
 .github/workflows/pr-test.yml         |  4 +-
 python/sglang/srt/models/llava.py     |  5 +--
 test/srt/test_vision_openai_server.py | 55 +++++++++++++++++++++++++++
 3 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 368a984ad65..baca881df0a 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -76,7 +76,7 @@ jobs:
         timeout-minutes: 20
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 5 --range-end 16
+          python3 run_suite.py --suite minimal --range-begin 5 --range-end 17
 
   unit-test-backend-part-3:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -96,7 +96,7 @@ jobs:
         timeout-minutes: 20
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 16
+          python3 run_suite.py --suite minimal --range-begin 17
 
   performance-test-1-gpu-part-1:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py
index 0ee11489299..beeab5679e7 100644
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -160,9 +160,6 @@ def forward(
                 image_sizes = [
                     image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
                 ]
-                image_offsets = [
-                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
-                ]
 
                 ########## Encode Image ########
 
@@ -358,7 +355,7 @@ def forward(
                     prefix_len = prefix_lens_cpu[i]
 
                     # Multiple images
-                    for j, image_offset in enumerate(image_offsets[i]):
+                    for j, image_offset in enumerate(image_inputs[i].image_offsets):
                         if image_offset < prefix_len:
                             continue
 
diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py
index 727f5774cad..6d57dfd2c30 100644
--- a/test/srt/test_vision_openai_server.py
+++ b/test/srt/test_vision_openai_server.py
@@ -1,8 +1,14 @@
+"""
+Usage:
+python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch
+"""
+
 import base64
 import io
 import json
 import os
 import unittest
+from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
 import openai
@@ -288,6 +294,55 @@ def test_regex(self):
         assert isinstance(js_obj["color"], str)
         assert isinstance(js_obj["number_of_cars"], int)
 
+    def run_decode_with_image(self, image_id):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+        content = []
+        if image_id == 0:
+            content.append(
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
+                    },
+                }
+            )
+        elif image_id == 1:
+            content.append(
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
+                    },
+                }
+            )
+        else:
+            pass
+
+        content.append(
+            {
+                "type": "text",
+                "text": "Describe this image in a very short sentence.",
+            }
+        )
+
+        response = client.chat.completions.create(
+            model="default",
+            messages=[
+                {"role": "user", "content": content},
+            ],
+            temperature=0,
+        )
+
+        assert response.choices[0].message.role == "assistant"
+        text = response.choices[0].message.content
+        assert isinstance(text, str)
+
+    def test_mixed_batch(self):
+        image_ids = [0, 1, 2] * 4
+        with ThreadPoolExecutor(4) as executor:
+            list(executor.map(self.run_decode_with_image, image_ids))
+
 
 if __name__ == "__main__":
     unittest.main()