
Commit

Paligemma Multimodal AI - inference tested on HF 3B Params
vamshikrishnakyatham committed Jan 27, 2025
0 parents commit d914130
Showing 8 changed files with 918 additions and 0 deletions.
425 changes: 425 additions & 0 deletions gemma_decoder.py

Large diffs are not rendered by default.

Binary file added image.jpg
108 changes: 108 additions & 0 deletions inference.py
@@ -0,0 +1,108 @@
from PIL import Image
import torch
import fire

from processing_paligemma import PaliGemmaProcessor
from gemma_decoder import KVCache, PaliGemmaForConditionalGeneration
from utils import load_hf_model

def move_inputs_to_device(model_inputs: dict, device: str):
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
return model_inputs


def get_model_inputs(processor: PaliGemmaProcessor, prompt: str, image_file_path: str, device: str):
image = Image.open(image_file_path)
images = [image]
prompts = [prompt]
model_inputs = processor(text=prompts, images=images)
model_inputs = move_inputs_to_device(model_inputs, device)
return model_inputs

def _sample_top_p(probs: torch.Tensor, p: float):
# (B, vocab_size)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# (B, vocab_size)
probs_sum = torch.cumsum(probs_sort, dim=-1)
# (B, vocab_size)
    # (Subtracting "probs_sort" shifts the cumulative sum one position to the right before masking)
mask = probs_sum - probs_sort > p
# Zero out all the probabilities of tokens that are not selected by the Top P
probs_sort[mask] = 0.0
# Redistribute the probabilities so that they sum up to 1.
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
# Sample a token (its index) from the top p distribution
next_token = torch.multinomial(probs_sort, num_samples=1)
# Get the token position in the vocabulary corresponding to the sampled index
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
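
# Worked example (illustrative, not part of the original file): with probs = [0.5, 0.3, 0.15, 0.05]
# and p = 0.9, probs_sort is already descending and probs_sum = [0.5, 0.8, 0.95, 1.0]. Subtracting
# probs_sort gives [0.0, 0.5, 0.8, 0.95], so only the last entry exceeds p and is zeroed out; the
# remaining mass (0.95) is renormalized before torch.multinomial draws the next token.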

def test_inference(model: PaliGemmaForConditionalGeneration, processor: PaliGemmaProcessor, device: str, prompt: str, image_file_path: str, max_tokens_to_generate: int, temperature: float, top_p: float, do_sample: bool):
model_inputs = get_model_inputs(processor, prompt, image_file_path, device)
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]
pixel_values = model_inputs["pixel_values"]
kv_cache = KVCache()
# Generate tokens until you see the stop token
stop_token = processor.tokenizer.eos_token_id
generated_tokens = []
for _ in range(max_tokens_to_generate):
        # Get the model outputs; the first iteration of the loop prefills the KV cache with the input prompt and image
outputs = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, kv_cache=kv_cache)
kv_cache = outputs["kv_cache"]
next_token_logits = outputs["logits"][:, -1, :]
# Sample the next token
if do_sample:
# Apply temperature
next_token_logits = torch.softmax(next_token_logits / temperature, dim=-1)
next_token = _sample_top_p(next_token_logits, top_p)
else:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) # greedy
assert next_token.size() == (1, 1)
next_token = next_token.squeeze(0) # Remove batch dimension
generated_tokens.append(next_token)
# Stop if the stop token has been generated
if next_token.item() == stop_token:
break
        # Feed only the newly generated token on the next step; the KV cache already holds keys/values for all previous positions
        input_ids = next_token.unsqueeze(-1)
        attention_mask = torch.cat([attention_mask, torch.ones((1, 1), device=input_ids.device)], dim=-1) # extend the attention mask with a 1 for the newly generated token
generated_tokens = torch.cat(generated_tokens, dim=-1)
# Decode the generated tokens
decoded = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
print("The input is: " + prompt + "\n")
print("The response is: " + decoded)

def main(model_path: str = None, prompt: str = None, image_file_path: str = None, max_tokens_to_generate: int = 100, temperature: float = 0.8, top_p: float = 0.9, do_sample: bool = False, only_cpu: bool = False):
device = "cpu"
if not only_cpu:
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"

print("Device in use: ", device)
print(f"Loading model")
model, tokenizer = load_hf_model(model_path, device)
model = model.to(device).eval()
num_image_tokens = model.config.vision_config.num_image_tokens
image_size = model.config.vision_config.image_size
processor = PaliGemmaProcessor(tokenizer, num_image_tokens, image_size)

print("Running inference")
with torch.no_grad():
test_inference(
model,
processor,
device,
prompt,
image_file_path,
max_tokens_to_generate,
temperature,
top_p,
do_sample,
)


if __name__ == "__main__":
fire.Fire(main)
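
# Example invocation (illustrative; run.sh below sets the same flags):
#   python inference.py --model_path "$HOME/multi-modal-ai/paligemma/paligemma-3b-pt-224" \
#       --prompt "describe the image" --image_file_path ./image.jpg --only_cpu True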
107 changes: 107 additions & 0 deletions processing_paligemma.py
@@ -0,0 +1,107 @@
from typing import Dict, List, Optional, Union, Tuple, Iterable
import numpy as np
from PIL import Image
import torch

IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]

def add_image_tokens_to_prompt(prefix_prompt, bos_token, image_seq_len, image_token):
    # The BOS (beginning-of-sentence) token sits between the image tokens and the text prompt tokens, with a newline ("\n") as the separator token. Note: the separator should really be tokenized separately, but the Hugging Face weights follow this convention, so it is appended to the prompt here.
return f"{image_token * image_seq_len}{bos_token}{prefix_prompt}\n"

def resize(image: Image.Image, size: Tuple[int, int], resample: Image.Resampling = None, reducing_gap: Optional[int] = None) -> Image.Image:
height, width = size
resized_image = image.resize(
(width, height), resample=resample, reducing_gap=reducing_gap
)
return resized_image

def rescale(image: np.ndarray, scale: float, dtype: np.dtype = np.float32) -> np.ndarray:
rescaled_image = image * scale
rescaled_image = rescaled_image.astype(dtype)
return rescaled_image

def normalize(image: np.ndarray, mean: Union[float, Iterable[float]], std: Union[float, Iterable[float]]) -> np.ndarray:
mean = np.array(mean, dtype=image.dtype)
std = np.array(std, dtype=image.dtype)
image = (image - mean) / std
return image

def process_images(images: List[Image.Image], size: Tuple[int, int] = None, resample: Image.Resampling = None, rescale_factor: float = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None) -> List[np.ndarray]:
    height, width = size
images = [
resize(image=image, size=(height, width), resample=resample) for image in images
]
# Convert each image to a numpy array
images = [np.array(image) for image in images]
# Rescale the pixel values to be in the range [0, 1]
images = [rescale(image, scale=rescale_factor) for image in images]
# Normalize the images to have mean 0 and standard deviation 1
images = [normalize(image, mean=image_mean, std=image_std) for image in images]
# Move the channel dimension to the first dimension. The model expects images in the format [Channel, Height, Width]
images = [image.transpose(2, 0, 1) for image in images]
return images

class PaliGemmaProcessor:

IMAGE_TOKEN = "<image>"

def __init__(self, tokenizer, num_image_tokens: int, image_size: int):
super().__init__()
self.image_seq_length = num_image_tokens
self.image_size = image_size

# Tokenizer described here: https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
tokens_to_add = {"additional_special_tokens": [self.IMAGE_TOKEN]}
tokenizer.add_special_tokens(tokens_to_add)
EXTRA_TOKENS = [
f"<loc{i:04d}>" for i in range(1024)
] # These tokens are used for object detection (bounding boxes)
EXTRA_TOKENS += [
f"<seg{i:03d}>" for i in range(128)
] # These tokens are used for object segmentation
tokenizer.add_tokens(EXTRA_TOKENS)
self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
# We will add the BOS and EOS tokens ourselves
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
self.tokenizer = tokenizer

def __call__(self, text: List[str], images: List[Image.Image], padding: str = "longest", truncation: bool = True) -> dict:
assert len(images) == 1 and len(text) == 1, f"Received {len(images)} images for {len(text)} prompts."
pixel_values = process_images(
images,
size=(self.image_size, self.image_size),
resample=Image.Resampling.BICUBIC,
rescale_factor= 1/255.0,
image_mean=IMAGENET_STANDARD_MEAN,
image_std=IMAGENET_STANDARD_STD,
)
# Convert the list of numpy arrays to a single numpy array with shape [Batch_Size, Channel, Height, Width]
pixel_values = np.stack(pixel_values, axis=0)
# Convert the numpy array to a PyTorch tensor
pixel_values = torch.tensor(pixel_values)

# Prepend a `self.image_seq_length` number of image tokens to the prompt
input_strings = [
add_image_tokens_to_prompt(
prefix_prompt=prompt,
bos_token=self.tokenizer.bos_token,
image_seq_len=self.image_seq_length,
image_token=self.IMAGE_TOKEN,
)
for prompt in text
]

# Returns the input_ids and attention_mask as PyTorch tensors
# Eg: HELLO WORLD -> [5, 2, 3] these are input_ids. Later we embed these to form something like [[...],[...],[...]]
inputs = self.tokenizer(
input_strings,
return_tensors="pt",
padding=padding,
truncation=truncation,
)

return_data = {"pixel_values": pixel_values, **inputs}
return return_data
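
# Usage sketch (illustrative, not part of the original file; values assume the 224-px PaliGemma checkpoint,
# which uses 256 image tokens, and a tokenizer loaded via utils.load_hf_model):
#   processor = PaliGemmaProcessor(tokenizer, num_image_tokens=256, image_size=224)
#   inputs = processor(text=["describe the image"], images=[Image.open("image.jpg")])
#   # inputs["pixel_values"] -> (1, 3, 224, 224); inputs["input_ids"] / inputs["attention_mask"] -> (1, seq_len)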
10 changes: 10 additions & 0 deletions requirements.txt
@@ -0,0 +1,10 @@
torch
transformers
numpy
pillow
safetensors
tokenizers
torchaudio
torchvision
tqdm
fire
20 changes: 20 additions & 0 deletions run.sh
@@ -0,0 +1,20 @@
#!/bin/bash

MODEL_PATH="$HOME/multi-modal-ai/paligemma/paligemma-3b-pt-224"
PROMPT="describe the image "
IMAGE_FILE_PATH="./image.jpg"
MAX_TOKENS_TO_GENERATE=100
TEMPERATURE=0.8
TOP_P=0.9
DO_SAMPLE="False"
ONLY_CPU="True"

python inference.py \
--model_path "$MODEL_PATH" \
--prompt "$PROMPT" \
--image_file_path "$IMAGE_FILE_PATH" \
--max_tokens_to_generate $MAX_TOKENS_TO_GENERATE \
--temperature $TEMPERATURE \
--top_p $TOP_P \
--do_sample $DO_SAMPLE \
    --only_cpu $ONLY_CPU