From f5d0865b252ff9eb95cec73caa784194463ea03a Mon Sep 17 00:00:00 2001
From: Ce Gao
Date: Fri, 3 Jan 2025 22:32:30 +0800
Subject: [PATCH] feat: Support VLM in reference_hf (#2726)

Signed-off-by: Ce Gao
---
 scripts/playground/reference_hf.py | 87 ++++++++++++++++++++++++++++--
 1 file changed, 84 insertions(+), 3 deletions(-)

diff --git a/scripts/playground/reference_hf.py b/scripts/playground/reference_hf.py
index 7901145c6d7..8f76948b684 100644
--- a/scripts/playground/reference_hf.py
+++ b/scripts/playground/reference_hf.py
@@ -25,12 +25,89 @@
 
 import argparse
 
+import requests
+from PIL import Image
+
 import torch
-from transformers import AutoModelForCausalLM
+from transformers import (
+    AutoModelForCausalLM, AutoProcessor, AutoModelForImageTextToText
+)
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
 
 
+@torch.no_grad()
+def vlm_text_with_image(args):
+    # Load the processor and model for ImageTextToText tasks
+    processor = AutoProcessor.from_pretrained(
+        args.model_path, trust_remote_code=True)
+    model = AutoModelForImageTextToText.from_pretrained(
+        args.model_path,
+        torch_dtype=args.dtype,
+        low_cpu_mem_usage=True,
+        device_map="auto",
+        trust_remote_code=True,
+    )
+
+    torch.cuda.set_device(0)
+
+    # List of image URLs to process
+    image_urls = [
+        "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
+    ]
+
+    # Conversation template for the processor
+    conversation = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                },
+                {
+                    "type": "text",
+                    "text": "Describe this image."
+                }
+            ]
+        }
+    ]
+
+    max_new_tokens = args.max_new_tokens
+
+    for i, url in enumerate(image_urls):
+        # Load the image from the URL
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        # Apply the chat template to the text prompt.
+        # Note that not all processors support chat templates;
+        # LLaVA and Qwen are two processors that support chat templates.
+        if not hasattr(processor, "apply_chat_template"):
+            raise ValueError("The processor does not support chat templates.")
+        text_prompt = processor.apply_chat_template(
+            conversation, add_generation_prompt=True)
+
+        # Prepare inputs for the model
+        inputs = processor(text=[text_prompt], images=[image],
+                           return_tensors="pt").to("cuda:0")
+
+        # Generate output from the model
+        output_ids = model.generate(
+            **inputs, do_sample=False, max_new_tokens=max_new_tokens
+        )
+        output_str = processor.decode(output_ids[0])
+
+        # Get the logits from the model's forward pass
+        outputs = model.forward(**inputs)
+        logits = outputs.logits[0, -1, :]
+
+        print(f"\n========== Image {i} ==========")
+        print("prefill logits (final)", logits)
+        # TODO(gaocegege): The output contains numerous <|image_pad|> tokens,
+        # making it cluttered and difficult to read.
+        # These tokens should be removed or cleaned up for better readability.
+        print(output_str)
+
+
 @torch.no_grad()
 def normal_text(args):
     t = get_tokenizer(args.model_path, trust_remote_code=True)
@@ -108,7 +185,11 @@
 
     parser.add_argument("--dtype", type=str, default="float16")
 
+    parser.add_argument("--model-type", type=str, default="text")
+
     args = parser.parse_args()
 
-    normal_text(args)
-    # synthetic_tokens(args)
+    if args.model_type == "vlm":
+        vlm_text_with_image(args)
+    else:
+        normal_text(args)
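
Note on the TODO(gaocegege) in vlm_text_with_image: one way to clean up the
printed output is to decode only the newly generated tokens and skip special
tokens, so the repeated <|image_pad|> placeholders are dropped. The sketch
below is only a suggestion, not part of the patch; it assumes generate returns
the prompt followed by the continuation and that the processor forwards
batch_decode to its underlying tokenizer (as the LLaVA and Qwen2-VL processors
do), and it would replace the processor.decode call inside the loop:

    # Decode only the continuation: slice off the prompt tokens (which include
    # the <|image_pad|> placeholders) and skip any remaining special tokens.
    prompt_len = inputs["input_ids"].shape[1]
    generated_ids = output_ids[:, prompt_len:]
    output_str = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]

For reference, the new code path could be exercised with a command along these
lines (the model path here is only an example):

    python3 scripts/playground/reference_hf.py \
        --model-path Qwen/Qwen2-VL-7B-Instruct \
        --model-type vlm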