inference.py
from PIL import Image
import torch
import fire

from processing_paligemma import PaliGemmaProcessor
from gemma_decoder import KVCache, PaliGemmaForConditionalGeneration
from utils import load_hf_model


def move_inputs_to_device(model_inputs: dict, device: str):
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
    return model_inputs


def get_model_inputs(processor: PaliGemmaProcessor, prompt: str, image_file_path: str, device: str):
    image = Image.open(image_file_path)
    images = [image]
    prompts = [prompt]
    model_inputs = processor(text=prompts, images=images)
    model_inputs = move_inputs_to_device(model_inputs, device)
    return model_inputs


def _sample_top_p(probs: torch.Tensor, p: float):
    # (B, vocab_size)
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    # (B, vocab_size)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # (B, vocab_size)
    # Subtracting "probs_sort" shifts the cumulative sum one position to the right before masking
    mask = probs_sum - probs_sort > p
    # Zero out the probabilities of all tokens outside the top-p nucleus
    probs_sort[mask] = 0.0
    # Renormalize so the remaining probabilities sum to 1
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # Sample a token (its index) from the top-p distribution
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled index back to its position in the vocabulary
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
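# Worked example for _sample_top_p (illustrative values, assumed for clarity):
# for a single row probs = [0.50, 0.30, 0.15, 0.05] and p = 0.8, the shifted cumulative
# sums are [0.00, 0.50, 0.80, 0.95], so only the last token exceeds p and is masked out.
# The surviving probabilities [0.50, 0.30, 0.15] are renormalized to roughly
# [0.526, 0.316, 0.158] before torch.multinomial draws the next token.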
def test_inference(
    model: PaliGemmaForConditionalGeneration,
    processor: PaliGemmaProcessor,
    device: str,
    prompt: str,
    image_file_path: str,
    max_tokens_to_generate: int,
    temperature: float,
    top_p: float,
    do_sample: bool,
):
    model_inputs = get_model_inputs(processor, prompt, image_file_path, device)
    input_ids = model_inputs["input_ids"]
    attention_mask = model_inputs["attention_mask"]
    pixel_values = model_inputs["pixel_values"]

    kv_cache = KVCache()

    # Generate tokens until the stop token appears or the budget is exhausted
    stop_token = processor.tokenizer.eos_token_id
    generated_tokens = []

    for _ in range(max_tokens_to_generate):
        # Get the model outputs; the first iteration prefills the cache with the prompt and image
        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
        )
        kv_cache = outputs["kv_cache"]
        next_token_logits = outputs["logits"][:, -1, :]
        # Sample the next token
        if do_sample:
            # Apply temperature, then sample from the top-p (nucleus) distribution
            next_token_logits = torch.softmax(next_token_logits / temperature, dim=-1)
            next_token = _sample_top_p(next_token_logits, top_p)
        else:
            # Greedy decoding
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        assert next_token.size() == (1, 1)
        next_token = next_token.squeeze(0)  # Remove the batch dimension
        generated_tokens.append(next_token)
        # Stop if the stop token has been generated
        if next_token.item() == stop_token:
            break
        # Feed only the newly generated token back in; the KV cache holds the earlier positions
        input_ids = next_token.unsqueeze(-1)
        # Extend the attention mask to cover the newly generated token
        attention_mask = torch.cat(
            [attention_mask, torch.ones((1, 1), device=input_ids.device)], dim=-1
        )

    generated_tokens = torch.cat(generated_tokens, dim=-1)
    # Decode the generated tokens
    decoded = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    print("The input is: " + prompt + "\n")
    print("The response is: " + decoded)
def main(
    model_path: str = None,
    prompt: str = None,
    image_file_path: str = None,
    max_tokens_to_generate: int = 100,
    temperature: float = 0.8,
    top_p: float = 0.9,
    do_sample: bool = False,
    only_cpu: bool = False,
):
    device = "cpu"
    if not only_cpu:
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"

    print("Device in use:", device)

    print("Loading model")
    model, tokenizer = load_hf_model(model_path, device)
    model = model.to(device).eval()

    num_image_tokens = model.config.vision_config.num_image_tokens
    image_size = model.config.vision_config.image_size
    processor = PaliGemmaProcessor(tokenizer, num_image_tokens, image_size)

    print("Running inference")
    with torch.no_grad():
        test_inference(
            model,
            processor,
            device,
            prompt,
            image_file_path,
            max_tokens_to_generate,
            temperature,
            top_p,
            do_sample,
        )


if __name__ == "__main__":
    fire.Fire(main)
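# Example invocation (paths and prompt below are placeholders, assumed for illustration):
#   python inference.py \
#       --model_path ./paligemma-3b-pt-224 \
#       --prompt "caption en" \
#       --image_file_path ./test_image.jpg \
#       --max_tokens_to_generate 100 \
#       --do_sample False
# fire.Fire exposes each keyword argument of main() as a command-line flag.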