feat: Support VLM in reference_hf (#2726)
Signed-off-by: Ce Gao <[email protected]>
gaocegege authored Jan 3, 2025
1 parent afdee7b commit f5d0865
Showing 1 changed file with 84 additions and 3 deletions.
scripts/playground/reference_hf.py (87 changes: 84 additions & 3 deletions)
@@ -25,12 +25,89 @@

 import argparse

+import requests
+from PIL import Image
+
 import torch
-from transformers import AutoModelForCausalLM
+from transformers import (
+    AutoModelForCausalLM, AutoProcessor, AutoModelForImageTextToText
+)

 from sglang.srt.hf_transformers_utils import get_tokenizer


+@torch.no_grad()
+def vlm_text_with_image(args):
+    # Load the processor and model for ImageTextToText tasks
+    processor = AutoProcessor.from_pretrained(
+        args.model_path, trust_remote_code=True)
+    model = AutoModelForImageTextToText.from_pretrained(
+        args.model_path,
+        torch_dtype=args.dtype,
+        low_cpu_mem_usage=True,
+        device_map="auto",
+        trust_remote_code=True,
+    )
+
+    torch.cuda.set_device(0)
+
+    # List of image URLs to process
+    image_urls = [
+        "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
+    ]
+
+    # Conversation template for the processor
+    conversation = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                },
+                {
+                    "type": "text",
+                    "text": "Describe this image."
+                }
+            ]
+        }
+    ]
+
+    max_new_tokens = args.max_new_tokens
+
+    for i, url in enumerate(image_urls):
+        # Load the image from the URL
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        # Apply the chat template to the text prompt.
+        # Note that not all processors support chat templates;
+        # LLaVA and Qwen are two processors that do.
+        if not hasattr(processor, "apply_chat_template"):
+            raise ValueError("The processor does not support chat templates.")
+        text_prompt = processor.apply_chat_template(
+            conversation, add_generation_prompt=True)
+
+        # Prepare inputs for the model
+        inputs = processor(text=[text_prompt], images=[image],
+                           return_tensors="pt").to("cuda:0")
+
+        # Generate output from the model
+        output_ids = model.generate(
+            **inputs, do_sample=False, max_new_tokens=max_new_tokens
+        )
+        output_str = processor.decode(output_ids[0])
+
+        # Get the logits from the model's forward pass
+        outputs = model.forward(**inputs)
+        logits = outputs.logits[0, -1, :]
+
+        print(f"\n========== Image {i} ==========")
+        print("prefill logits (final)", logits)
+        # TODO(gaocegege): The output contains numerous <|image_pad|> tokens,
+        # making it cluttered and difficult to read.
+        # These tokens should be removed or cleaned up for better readability.
+        print(output_str)
+
+
 @torch.no_grad()
 def normal_text(args):
     t = get_tokenizer(args.model_path, trust_remote_code=True)
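
The TODO above asks for a cleaner decoded output: the prompt portion of output_ids still contains the image placeholder tokens, so they show up in the printed string. A minimal sketch of one possible fix (an assumption, not part of this commit) is to decode only the tokens generated after the prompt and let the tokenizer drop special tokens such as <|image_pad|>. This assumes the processor forwards decode kwargs to its tokenizer, as the LLaVA and Qwen processors do.

    # Hypothetical cleanup for the TODO above; not part of this commit.
    # Slice off the prompt (which includes the image placeholder tokens)
    # and skip special tokens while decoding the generated tail.
    prompt_len = inputs["input_ids"].shape[1]
    generated_ids = output_ids[0][prompt_len:]
    output_str = processor.decode(generated_ids, skip_special_tokens=True)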
@@ -108,7 +185,11 @@ def synthetic_tokens(args):

parser.add_argument("--dtype", type=str, default="float16")

parser.add_argument("--model-type", type=str, default="text")

args = parser.parse_args()

normal_text(args)
# synthetic_tokens(args)
if args.model_type == "vlm":
vlm_text_with_image(args)
else:
normal_text(args)
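
With this change, the reference script can produce HuggingFace ground-truth outputs for vision-language models as well as plain text models. A hypothetical invocation (the model name is only an example, not taken from the commit; any checkpoint loadable through AutoModelForImageTextToText whose processor supports chat templates should fit):

    python scripts/playground/reference_hf.py --model-path Qwen/Qwen2-VL-7B-Instruct --model-type vlm

Leaving --model-type at its default of "text" preserves the previous behavior.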
