Update to split_image #4

Merged · 3 commits · Oct 1, 2024

4 changes: 4 additions & 0 deletions aria/config.py
@@ -65,6 +65,10 @@ class AriaModelConfig(ModelConfig):
"choices": [490, 980],
},
)
split_image: bool = field(
default=False,
metadata={"help": "Whether to split the image into smaller patches."},
)

def __post_init__(self):
super().__post_init__()
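
For context, a minimal standalone sketch of how the new option behaves as a dataclass field (DemoConfig below is a hypothetical stand-in, not the real AriaModelConfig):

```python
# Hypothetical stand-in for AriaModelConfig, showing the dataclasses.field
# pattern used for the new split_image option.
from dataclasses import dataclass, field


@dataclass
class DemoConfig:
    max_image_size: int = 980
    split_image: bool = field(
        default=False,
        metadata={"help": "Whether to split the image into smaller patches."},
    )


cfg = DemoConfig(split_image=True)
print(cfg.split_image)  # True
print(DemoConfig.__dataclass_fields__["split_image"].metadata["help"])
```
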
5 changes: 3 additions & 2 deletions aria/data.py
@@ -19,7 +19,7 @@

import os
import warnings
from typing import Dict, List
from typing import Dict, Iterable, List

import torch
from datasets import DatasetDict, concatenate_datasets, load_dataset
@@ -29,6 +29,7 @@
def apply_chat_template_and_tokenize(
messages_batch: List[List[Dict]],
tokenizer,
num_image_crop: Iterable[torch.Tensor] = iter([]),
):
IGNORE_TOKEN_ID = -100
im_start_tokens = tokenizer("<|im_start|>").input_ids
@@ -41,7 +42,7 @@ def process_content(content):
if content["type"] == "text":
return content["text"]
elif content["type"] == "image":
return "<fim_prefix><|img|><fim_suffix>"
return "<fim_prefix>" + "<|img|>" * next(num_image_crop) + "<fim_suffix>"
else:
raise ValueError(f"Unknown content type {content['type']} in message")

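To illustrate the change above, a small self-contained sketch of the per-image token expansion (process_content here is a simplified stand-in for the nested helper inside apply_chat_template_and_tokenize; the crop counts are assumed to come from the vision processor's num_crops output):

```python
# Simplified stand-in for the process_content helper: each image entry emits
# one <|img|> placeholder per crop produced by the vision processor.
from typing import Dict, Iterator

import torch


def process_content(content: Dict, num_image_crop: Iterator[torch.Tensor]) -> str:
    if content["type"] == "text":
        return content["text"]
    elif content["type"] == "image":
        # scalar integer tensors support string repetition; a plain int works too
        return "<fim_prefix>" + "<|img|>" * next(num_image_crop) + "<fim_suffix>"
    raise ValueError(f"Unknown content type {content['type']} in message")


crops = iter([torch.tensor(3)])  # e.g. 2 patches plus the full image
print(process_content({"type": "image"}, crops))
# <fim_prefix><|img|><|img|><|img|><fim_suffix>
```
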
22 changes: 19 additions & 3 deletions aria/inference.py
@@ -42,6 +42,13 @@ def parse_arguments():
help="Maximum size of the image to be processed",
default=980,
)
parser.add_argument(
"--split_image",
action="store_true",
help="Whether to split the image into patches",
)
return parser.parse_args()


@@ -65,7 +72,9 @@ def load_model(base_model_path, peft_model_path=None):
return model


def prepare_input(image_path, prompt, processor: AriaProcessor, max_image_size):
def prepare_input(
image_path, prompt, processor: AriaProcessor, max_image_size, split_image
):
image = Image.open(image_path)

messages = [
@@ -85,6 +94,7 @@ def prepare_input(image_path, prompt, processor: AriaProcessor, max_image_size):
images=image,
return_tensors="pt",
max_image_size=max_image_size,
split_image=split_image,
)

return inputs
@@ -96,8 +106,9 @@ def inference(
model: AriaForConditionalGeneration,
processor: AriaProcessor,
max_image_size,
split_image,
):
inputs = prepare_input(image_path, prompt, processor, max_image_size)
inputs = prepare_input(image_path, prompt, processor, max_image_size, split_image)
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

@@ -126,7 +137,12 @@ def main():
model = load_model(args.base_model_path, args.peft_model_path)

result = inference(
args.image_path, args.prompt, model, processor, args.max_image_size
args.image_path,
args.prompt,
model,
processor,
args.max_image_size,
args.split_image,
)
print(result)

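For reference, a minimal sketch of how the new --split_image flag behaves (the parser below is a hypothetical stand-in for parse_arguments; note that store_true already implies a False default, takes no value, and must not be combined with a type= argument):

```python
# Minimal argparse sketch of the boolean flag added above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--split_image",
    action="store_true",
    help="Whether to split the image into patches",
)

print(parser.parse_args([]).split_image)                 # False
print(parser.parse_args(["--split_image"]).split_image)  # True
```
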
35 changes: 25 additions & 10 deletions aria/model/processing_aria.py
@@ -18,6 +18,7 @@
# under the License.

import inspect
import re
from typing import List, Optional, Union

from transformers import AutoTokenizer, BatchFeature
@@ -61,7 +62,7 @@ def __init__(
super().__init__(chat_template=chat_template)

if image_processor is None:
self.image_processor = AriaVisionProcessor(image_max_size=patch_size)
self.image_processor = AriaVisionProcessor(max_image_size=patch_size)
else:
self.image_processor = image_processor

@@ -87,6 +88,7 @@ def __call__(
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
max_image_size: Optional[int] = 980,
split_image: Optional[bool] = False,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
) -> BatchFeature:
"""
@@ -114,6 +116,8 @@
Maximum length of the returned list and optionally padding length (see above).
max_image_size (`int`, *optional*):
Maximum size of the image to be processed.
split_image (`bool`, *optional*):
Whether to split the image into patches before processing.
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
@@ -134,24 +138,35 @@
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
"""
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError(
"Invalid input text. Please provide a string, or a list of strings"
)

if images is not None:
image_inputs = self.image_processor(
images,
return_tensors=return_tensors,
max_image_size=max_image_size,
split_image=split_image,
)
# expand the image token according to the num_crops of each image
prompt_strings = []
crop_iter = iter(image_inputs.pop("num_crops"))
for prompt in text:
prompt_strings.append(
re.sub(
re.escape(self.image_token),
lambda _: next(crop_iter) * self.image_token,
prompt,
)
)

else:
image_inputs = {}

if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError(
"Invalid input text. Please provide a string, or a list of strings"
)

prompt_strings = text

text_inputs = self.tokenizer(
prompt_strings,
return_tensors=return_tensors,
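
A standalone sketch of the placeholder expansion performed above: every <|img|> in the prompt is repeated once per crop, so the number of image tokens matches the number of crops produced by the vision processor (the prompt and crop count below are made up for illustration):

```python
# Standalone illustration of the re.sub-based expansion: each <|img|>
# placeholder is replaced by num_crops copies of itself.
import re

image_token = "<|img|>"
prompts = ["<fim_prefix><|img|><fim_suffix>Describe the image."]
num_crops = [3]  # one count per image in the batch, e.g. 2 patches plus the full image

crop_iter = iter(num_crops)
prompt_strings = [
    re.sub(re.escape(image_token), lambda _: next(crop_iter) * image_token, prompt)
    for prompt in prompts
]
print(prompt_strings[0])
# <fim_prefix><|img|><|img|><|img|><fim_suffix>Describe the image.
```
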
125 changes: 119 additions & 6 deletions aria/model/vision_processor.py
@@ -19,12 +19,93 @@

from typing import List, Optional, Union

import numpy as np
import torch
from PIL import Image, ImageOps
from torchvision import transforms
from transformers import BaseImageProcessor, BatchFeature, TensorType


def _select_best_resolution(
img_width: int, img_height: int, target_ratios: List[List[int]], patch_size: int
):
"""
Selects the best resolution from a list of possible resolutions based on the original size.

Args:
img_width (int): the original width of the image.
img_height (int): the original height of the image.
target_ratios (List[List[int]]): candidate (width, height) ratio pairs, shape (M, 2).
patch_size (int): image patch size.

Returns:
tuple: The best-fit ratio in the format (width, height).
"""

aspect_ratio = img_width / img_height
best_ratio_diff = float("inf")
best_ratio_w, best_ratio_h = 1, 1
area = np.int32(img_width) * np.int32(img_height)
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio_w, best_ratio_h = ratio[0], ratio[1]
elif (
ratio_diff == best_ratio_diff
and area > 0.5 * patch_size * patch_size * ratio[0] * ratio[1]
):
best_ratio_w, best_ratio_h = ratio[0], ratio[1]

return best_ratio_w, best_ratio_h


def _split_image(
image: Image.Image,
split_image: bool,
split_ratio: List[List[int]],
patch_size: int,
) -> List[Image.Image]:
"""
Split image into multiple patches

Args:
image (PIL.Image): Input image.
split_image (bool): Whether to split the image into patches.
split_ratio (List[List[int]]): candidate (width, height) ratio pairs, shape (M, 2).
patch_size (int): image patch size.

Returns:
List[PIL.Image]: List of split images.
"""
if split_image:
ratio_width, ratio_height = _select_best_resolution(
image.width, image.height, split_ratio, patch_size
)
resize_width = patch_size * ratio_width
resize_height = patch_size * ratio_height
blocks = ratio_width * ratio_height
resized_img = image.resize((resize_width, resize_height))
processed_images = []
for i in range(blocks):
box = (
(i % (resize_width // patch_size)) * patch_size,
(i // (resize_width // patch_size)) * patch_size,
((i % (resize_width // patch_size)) + 1) * patch_size,
((i // (resize_width // patch_size)) + 1) * patch_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if len(processed_images) != 1:
processed_images.insert(0, image)
return processed_images
else:
return [image]


def keep_ratio_resize_and_pixel_mask(
img: Image.Image, max_size, min_size=336, padding_value=0
):
@@ -127,6 +208,17 @@ def __call__(
max_image_size: Optional[int] = 980,
min_image_size: Optional[int] = 336,
return_tensors: Optional[Union[str, TensorType]] = "pt",
split_image: Optional[bool] = False,
split_ratio: Optional[List[List[int]]] = [
[1, 1],
[1, 2],
[1, 3],
[1, 4],
[2, 2],
[2, 1],
[3, 1],
[4, 1],
],
):
"""
Process a list of images.
@@ -135,13 +227,16 @@
images (list): List of PIL.Image objects.
max_image_size (int, optional): Override the default max image size. Defaults to None.
return_tensors (str or TensorType, optional): The type of tensor to return. Defaults to "pt".
split_image (bool, optional): Whether to split the image. Defaults to False.
split_ratio (list, optional): The ratio for splitting the image. Defaults to a list of common split ratios.
Returns:
BatchFeature: A BatchFeature object containing:
- 'pixel_values': Tensor of processed image pixel values.
- 'pixel_mask': Boolean pixel mask. This mask is a 2D tensor of shape (max_size, max_size) where:
- True (1) values indicate pixels that belong to the original resized image.
- False (0) values indicate pixels that are part of the padding.
The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
- 'num_crops': Tensor of the number of crops for each image.
"""
max_size = self.max_image_size if max_image_size is None else max_image_size
min_size = self.min_image_size if min_image_size is None else min_image_size
@@ -154,19 +249,24 @@

pixel_values = []
pixel_masks = []
num_crops = []

for image in images:
img_padded, pixel_mask = keep_ratio_resize_and_pixel_mask(
image, max_size, min_size
)
img_padded = self.transform(img_padded)
pixel_values.append(img_padded)
pixel_masks.append(pixel_mask)
crop_images = _split_image(image, split_image, split_ratio, max_size)
num_crops.append(torch.tensor(len(crop_images)))
for crop_image in crop_images:
img_padded, pixel_mask = keep_ratio_resize_and_pixel_mask(
crop_image, max_size, min_size
)
img_padded = self.transform(img_padded)
pixel_values.append(img_padded)
pixel_masks.append(pixel_mask)

return BatchFeature(
data={
"pixel_values": torch.stack(pixel_values),
"pixel_mask": torch.stack(pixel_masks),
"num_crops": torch.stack(num_crops),
},
tensor_type=return_tensors,
)
@@ -177,10 +277,23 @@ def preprocess(
max_image_size=None,
min_image_size=None,
return_tensors: Optional[Union[str, TensorType]] = None,
split_image: Optional[bool] = False,
split_ratio: Optional[List[List[int]]] = [
[1, 1],
[1, 2],
[1, 3],
[1, 4],
[2, 2],
[2, 1],
[3, 1],
[4, 1],
],
):
return self.__call__(
images,
max_image_size=max_image_size,
min_image_size=min_image_size,
return_tensors=return_tensors,
split_image=split_image,
split_ratio=split_ratio,
)
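
To make the cropping geometry concrete, an illustrative reimplementation of the grid split (grid_crops is not the library function; it mirrors _split_image under the assumption that the best (width, height) ratio has already been selected):

```python
# Illustrative grid split: resize to (ratio_w * patch, ratio_h * patch),
# cut patch-sized tiles row by row, and prepend the full image when the
# grid yields more than one tile.
from typing import List

from PIL import Image


def grid_crops(
    image: Image.Image, ratio_w: int, ratio_h: int, patch_size: int
) -> List[Image.Image]:
    resized = image.resize((patch_size * ratio_w, patch_size * ratio_h))
    tiles = []
    for row in range(ratio_h):
        for col in range(ratio_w):
            box = (
                col * patch_size,
                row * patch_size,
                (col + 1) * patch_size,
                (row + 1) * patch_size,
            )
            tiles.append(resized.crop(box))
    return [image] + tiles if len(tiles) > 1 else tiles


img = Image.new("RGB", (1960, 980))
crops = grid_crops(img, ratio_w=2, ratio_h=1, patch_size=980)
print(len(crops))  # 3 -> the original image plus two 980x980 tiles
```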