PaliGemma Multimodal AI - inference tested on HF 3B params
commit d914130
Showing 8 changed files with 918 additions and 0 deletions.
@@ -0,0 +1,108 @@
from PIL import Image
import torch
import fire

from processing_paligemma import PaliGemmaProcessor
from gemma_decoder import KVCache, PaliGemmaForConditionalGeneration
from utils import load_hf_model


def move_inputs_to_device(model_inputs: dict, device: str):
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
    return model_inputs


def get_model_inputs(processor: PaliGemmaProcessor, prompt: str, image_file_path: str, device: str):
    image = Image.open(image_file_path)
    images = [image]
    prompts = [prompt]
    model_inputs = processor(text=prompts, images=images)
    model_inputs = move_inputs_to_device(model_inputs, device)
    return model_inputs


def _sample_top_p(probs: torch.Tensor, p: float):
    # (B, vocab_size)
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    # (B, vocab_size)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # (B, vocab_size)
    # (Subtracting "probs_sort" shifts the cumulative sum one position to the right before masking)
    mask = probs_sum - probs_sort > p
    # Zero out the probabilities of all tokens not selected by top-p
    probs_sort[mask] = 0.0
    # Redistribute the probabilities so that they sum to 1
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # Sample a token (its index) from the top-p distribution
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled index back to the token's position in the vocabulary
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


def test_inference(model: PaliGemmaForConditionalGeneration, processor: PaliGemmaProcessor, device: str, prompt: str, image_file_path: str, max_tokens_to_generate: int, temperature: float, top_p: float, do_sample: bool):
    model_inputs = get_model_inputs(processor, prompt, image_file_path, device)
    input_ids = model_inputs["input_ids"]
    attention_mask = model_inputs["attention_mask"]
    pixel_values = model_inputs["pixel_values"]
    kv_cache = KVCache()
    # Generate tokens until the stop token appears
    stop_token = processor.tokenizer.eos_token_id
    generated_tokens = []
    for _ in range(max_tokens_to_generate):
        # Get the model outputs; the first iteration of the loop prefills the input prompt and image
        outputs = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, kv_cache=kv_cache)
        kv_cache = outputs["kv_cache"]
        next_token_logits = outputs["logits"][:, -1, :]
        # Sample the next token
        if do_sample:
            # Apply temperature, then sample from the top-p (nucleus) distribution
            next_token_logits = torch.softmax(next_token_logits / temperature, dim=-1)
            next_token = _sample_top_p(next_token_logits, top_p)
        else:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # greedy
        assert next_token.size() == (1, 1)
        next_token = next_token.squeeze(0)  # Remove the batch dimension
        generated_tokens.append(next_token)
        # Stop if the stop token has been generated
        if next_token.item() == stop_token:
            break
        # Feed only the new token back in (the KV cache holds the rest) and extend the attention mask to cover it
        input_ids = next_token.unsqueeze(-1)
        attention_mask = torch.cat([attention_mask, torch.ones((1, 1), device=input_ids.device)], dim=-1)
    generated_tokens = torch.cat(generated_tokens, dim=-1)
    # Decode the generated tokens
    decoded = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    print("The input is: " + prompt + "\n")
    print("The response is: " + decoded)


def main(model_path: str = None, prompt: str = None, image_file_path: str = None, max_tokens_to_generate: int = 100, temperature: float = 0.8, top_p: float = 0.9, do_sample: bool = False, only_cpu: bool = False):
    device = "cpu"
    if not only_cpu:
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"

    print("Device in use: ", device)
    print("Loading model")
    model, tokenizer = load_hf_model(model_path, device)
    model = model.to(device).eval()
    num_image_tokens = model.config.vision_config.num_image_tokens
    image_size = model.config.vision_config.image_size
    processor = PaliGemmaProcessor(tokenizer, num_image_tokens, image_size)

    print("Running inference")
    with torch.no_grad():
        test_inference(
            model,
            processor,
            device,
            prompt,
            image_file_path,
            max_tokens_to_generate,
            temperature,
            top_p,
            do_sample,
        )


if __name__ == "__main__":
    fire.Fire(main)
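
For orientation (not part of the commit), the sketch below walks the shifted-cumulative-sum trick from `_sample_top_p` through a made-up five-token distribution with `top_p=0.9`: a token stays in the candidate set only if the cumulative probability of the tokens ranked above it is at most `p`, after which the survivors are renormalized and sampled.

# Toy illustration of _sample_top_p; the probabilities are invented for the example.
import torch

probs = torch.tensor([[0.50, 0.30, 0.15, 0.04, 0.01]])  # (B=1, vocab_size=5)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# Cumulative mass before each token: [0.00, 0.50, 0.80, 0.95, 0.99]
mask = torch.cumsum(probs_sort, dim=-1) - probs_sort > 0.9
probs_sort[mask] = 0.0  # the last two tokens exceed the 0.9 budget and are dropped
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.gather(probs_idx, -1, torch.multinomial(probs_sort, num_samples=1))
print(next_token)  # always one of token ids 0, 1, or 2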
@@ -0,0 +1,107 @@
from typing import Dict, List, Optional, Union, Tuple, Iterable
import numpy as np
from PIL import Image
import torch


IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]


def add_image_tokens_to_prompt(prefix_prompt, bos_token, image_seq_len, image_token):
    # The image tokens come first, then the BOS token, then the text prompt, then a newline as the
    # separator token. Strictly speaking the separator should be tokenized separately, but the
    # Hugging Face weights follow this layout, so it is kept as-is.
    return f"{image_token * image_seq_len}{bos_token}{prefix_prompt}\n"


def resize(image: Image.Image, size: Tuple[int, int], resample: Image.Resampling = None, reducing_gap: Optional[int] = None) -> Image.Image:
    height, width = size
    resized_image = image.resize(
        (width, height), resample=resample, reducing_gap=reducing_gap
    )
    return resized_image


def rescale(image: np.ndarray, scale: float, dtype: np.dtype = np.float32) -> np.ndarray:
    rescaled_image = image * scale
    rescaled_image = rescaled_image.astype(dtype)
    return rescaled_image


def normalize(image: np.ndarray, mean: Union[float, Iterable[float]], std: Union[float, Iterable[float]]) -> np.ndarray:
    mean = np.array(mean, dtype=image.dtype)
    std = np.array(std, dtype=image.dtype)
    image = (image - mean) / std
    return image


def process_images(images: List[Image.Image], size: Tuple[int, int] = None, resample: Image.Resampling = None, rescale_factor: float = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None) -> List[np.ndarray]:
    height, width = size[0], size[1]
    images = [
        resize(image=image, size=(height, width), resample=resample) for image in images
    ]
    # Convert each image to a numpy array
    images = [np.array(image) for image in images]
    # Rescale the pixel values to be in the range [0, 1]
    images = [rescale(image, scale=rescale_factor) for image in images]
    # Normalize the images to have mean 0 and standard deviation 1
    images = [normalize(image, mean=image_mean, std=image_std) for image in images]
    # Move the channel dimension to the front. The model expects images in the format [Channel, Height, Width]
    images = [image.transpose(2, 0, 1) for image in images]
    return images


class PaliGemmaProcessor:

    IMAGE_TOKEN = "<image>"

    def __init__(self, tokenizer, num_image_tokens: int, image_size: int):
        super().__init__()
        self.image_seq_length = num_image_tokens
        self.image_size = image_size

        # Tokenizer described here: https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
        tokens_to_add = {"additional_special_tokens": [self.IMAGE_TOKEN]}
        tokenizer.add_special_tokens(tokens_to_add)
        EXTRA_TOKENS = [
            f"<loc{i:04d}>" for i in range(1024)
        ]  # These tokens are used for object detection (bounding boxes)
        EXTRA_TOKENS += [
            f"<seg{i:03d}>" for i in range(128)
        ]  # These tokens are used for object segmentation
        tokenizer.add_tokens(EXTRA_TOKENS)
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
        # We will add the BOS and EOS tokens ourselves
        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False
        self.tokenizer = tokenizer

    def __call__(self, text: List[str], images: List[Image.Image], padding: str = "longest", truncation: bool = True) -> dict:
        assert len(images) == 1 and len(text) == 1, f"Received {len(images)} images for {len(text)} prompts."
        pixel_values = process_images(
            images,
            size=(self.image_size, self.image_size),
            resample=Image.Resampling.BICUBIC,
            rescale_factor=1 / 255.0,
            image_mean=IMAGENET_STANDARD_MEAN,
            image_std=IMAGENET_STANDARD_STD,
        )
        # Convert the list of numpy arrays to a single numpy array with shape [Batch_Size, Channel, Height, Width]
        pixel_values = np.stack(pixel_values, axis=0)
        # Convert the numpy array to a PyTorch tensor
        pixel_values = torch.tensor(pixel_values)

        # Prepend `self.image_seq_length` image tokens to the prompt
        input_strings = [
            add_image_tokens_to_prompt(
                prefix_prompt=prompt,
                bos_token=self.tokenizer.bos_token,
                image_seq_len=self.image_seq_length,
                image_token=self.IMAGE_TOKEN,
            )
            for prompt in text
        ]

        # Returns the input_ids and attention_mask as PyTorch tensors
        # E.g. "HELLO WORLD" -> [5, 2, 3]; these input_ids are later embedded into vectors like [[...], [...], [...]]
        inputs = self.tokenizer(
            input_strings,
            return_tensors="pt",
            padding=padding,
            truncation=truncation,
        )

        return_data = {"pixel_values": pixel_values, **inputs}
        return return_data
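
As a quick, hypothetical illustration (not part of the commit) of the prompt layout the processor builds: `add_image_tokens_to_prompt` places the image placeholder tokens first, then BOS, then the text, then the trailing newline separator. The 256 image tokens below are an assumption matching the 224px checkpoint.

# Hypothetical usage of add_image_tokens_to_prompt; 256 image tokens is assumed for paligemma-3b-pt-224.
from processing_paligemma import add_image_tokens_to_prompt

s = add_image_tokens_to_prompt(
    prefix_prompt="describe the image",
    bos_token="<bos>",
    image_seq_len=256,
    image_token="<image>",
)
print(s.startswith("<image>" * 256))            # True: image placeholders come first
print(s.endswith("<bos>describe the image\n"))  # True: BOS + text + newline separator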
@@ -0,0 +1,10 @@
torch
transformers
numpy
pillow
safetensors
tokenizers
torchaudio
torchvision
tqdm
fire
@@ -0,0 +1,20 @@
#!/bin/bash

MODEL_PATH="$HOME/multi-modal-ai/paligemma/paligemma-3b-pt-224"
PROMPT="describe the image "
IMAGE_FILE_PATH="./image.jpg"
MAX_TOKENS_TO_GENERATE=100
TEMPERATURE=0.8
TOP_P=0.9
DO_SAMPLE="False"
ONLY_CPU="True"

python inference.py \
    --model_path "$MODEL_PATH" \
    --prompt "$PROMPT" \
    --image_file_path "$IMAGE_FILE_PATH" \
    --max_tokens_to_generate $MAX_TOKENS_TO_GENERATE \
    --temperature $TEMPERATURE \
    --top_p $TOP_P \
    --do_sample $DO_SAMPLE \
    --only_cpu $ONLY_CPU
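
Equivalently, a minimal sketch (assuming the files above sit in the working directory) that starts the same run from Python without going through the shell script or Fire:

# Sketch: call main() directly; the model path is the same placeholder the script uses.
import os
from inference import main

main(
    model_path=os.path.expanduser("~/multi-modal-ai/paligemma/paligemma-3b-pt-224"),
    prompt="describe the image ",
    image_file_path="./image.jpg",
    max_tokens_to_generate=100,
    temperature=0.8,
    top_p=0.9,
    do_sample=False,
    only_cpu=True,
)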