diff --git a/src/r1_vlm/datasets/text_vqa/collect_demos.py b/src/r1_vlm/datasets/text_vqa/collect_demos.py new file mode 100644 index 0000000..23cad67 --- /dev/null +++ b/src/r1_vlm/datasets/text_vqa/collect_demos.py @@ -0,0 +1,284 @@ +import json +import os + +import matplotlib.pyplot as plt +import matplotlib.widgets as widgets +from datasets import load_dataset +from PIL import Image + +OUTPUT_FILE = "zoom_demos.jsonl" + + +# --- Image Resizing Logic (copied from text_vqa_tool_use_r1.py) --- +def resize_image(image): + """ + Resizes the image if its longer side exceeds 1024 pixels, + maintaining aspect ratio. + """ + # if the longer side is greater than 1024, resize it so the longer side is 1024 + if image.mode == "L": # Handle grayscale images + image = image.convert("RGB") + + if image.size[0] > 1024 or image.size[1] > 1024: + longer_side = max(image.size[0], image.size[1]) + image = image.resize( + ( + int(1024 * image.size[0] / longer_side), + int(1024 * image.size[1] / longer_side), + ), + Image.Resampling.LANCZOS, # Use LANCZOS for better quality + ) + return image + + +# --- Global State (simpler for this script) --- +current_example = None +original_image = None +original_size = None +displayed_image = None +displayed_size = None +selected_keypoint_original = None +selected_keypoint_display = None +marker_plot = None +dataset_iterator = None +demos_saved_count = 0 + +# --- Matplotlib UI Elements --- +fig, ax = plt.subplots(figsize=(10, 8)) # Make figure a bit larger +plt.subplots_adjust(bottom=0.25) # More space for controls and title +ax.axis("off") + +ax_save = plt.axes([0.81, 0.05, 0.15, 0.075]) +btn_save = widgets.Button(ax_save, "Save & Next") +btn_save.active = False # Initially disabled + +ax_undo = plt.axes([0.65, 0.05, 0.15, 0.075]) +btn_undo = widgets.Button(ax_undo, "Undo Click") +btn_undo.active = False # Initially disabled + +ax_skip = plt.axes([0.49, 0.05, 0.15, 0.075]) +btn_skip = widgets.Button(ax_skip, "Skip") + +# Text to show saved count +ax_count_text = plt.axes([0.05, 0.05, 0.4, 0.075]) +ax_count_text.axis("off") +count_text_obj = ax_count_text.text( + 0.5, + 0.5, + f"Saved: {demos_saved_count}", + ha="center", + va="center", + transform=ax_count_text.transAxes, +) + + +# --- Core Logic Functions --- +def update_saved_count_display(): + """Updates the saved count text on the plot.""" + global demos_saved_count + count_text_obj.set_text(f"Saved: {demos_saved_count}") + plt.draw() + + +def remove_marker(): + """Removes the keypoint marker from the plot if it exists.""" + global marker_plot + if marker_plot: + try: + marker_plot.remove() + except ValueError: # May have already been removed if plot cleared + pass + marker_plot = None + + +def display_current_example(): + """Displays the current image and question.""" + global displayed_image, displayed_size, marker_plot + if current_example is None: + return + + remove_marker() # Clear marker from previous example + + # Store original and prepare display image + original_image_pil = current_example["image"] + if original_image_pil.mode == "L": # Ensure RGB + original_image_pil = original_image_pil.convert("RGB") + + displayed_image = resize_image(original_image_pil.copy()) # Use resized for display + displayed_size = displayed_image.size + + ax.cla() # Clear axes before drawing new image + ax.imshow(displayed_image) + answers = current_example.get("answers", []) # Get answers, default to empty list + answers_str = ", ".join(answers) if answers else "N/A" + question_text = ( + f"Q 
({current_example['question_id']}): {current_example['question']}\n" + f"Answers: {answers_str}" + ) + ax.set_title( + question_text, wrap=True, fontsize=9 + ) # Slightly smaller font for more text + ax.axis("off") + plt.draw() + + +def load_next_example(): + """Loads the next example from the dataset iterator and updates the UI.""" + global current_example, original_image, original_size + global selected_keypoint_original, selected_keypoint_display, marker_plot + global dataset_iterator + + if dataset_iterator is None: + print("Error: Dataset iterator not initialized.") + plt.close(fig) + return + + try: + current_example = next(dataset_iterator) + original_image = current_example["image"] + if original_image.mode == "L": # Ensure RGB for original too + original_image = original_image.convert("RGB") + original_size = original_image.size + + # Reset state for the new example + selected_keypoint_original = None + selected_keypoint_display = None + remove_marker() # Explicitly remove marker here too + + display_current_example() + + btn_save.active = False + btn_undo.active = False + plt.draw() # Update button states visually + + except StopIteration: + print("Finished processing all examples in the dataset split.") + plt.close(fig) + except Exception as e: + print(f"An unexpected error occurred loading the next example: {e}") + plt.close(fig) + + +def on_click(event): + """Handles mouse clicks on the image axes to select a keypoint.""" + global selected_keypoint_original, selected_keypoint_display, marker_plot + global original_size, displayed_size + + if event.inaxes != ax or original_size is None or displayed_size is None: + return # Click outside image axes or data not loaded + + # Get click coordinates relative to displayed image + x_disp, y_disp = int(event.xdata), int(event.ydata) + + # Map coordinates back to original image dimensions + scale_w = original_size[0] / displayed_size[0] + scale_h = original_size[1] / displayed_size[1] + x_orig = int(round(x_disp * scale_w)) + y_orig = int(round(y_disp * scale_h)) + + # Clamp coordinates to ensure they are within the original image bounds + x_orig = max(0, min(original_size[0] - 1, x_orig)) + y_orig = max(0, min(original_size[1] - 1, y_orig)) + + selected_keypoint_original = [x_orig, y_orig] + selected_keypoint_display = [x_disp, y_disp] + + # Update marker + remove_marker() + (marker_plot,) = ax.plot(x_disp, y_disp, "r+", markersize=12, markeredgewidth=2) + + # Enable Save and Undo buttons + btn_save.active = True + btn_undo.active = True + plt.draw() + + +def save_callback(event): + """Saves the current example's info and keypoint, then loads the next.""" + global demos_saved_count + if current_example is None or selected_keypoint_original is None: + print("Warning: Cannot save - no example loaded or keypoint selected.") + return + + data = { + "image_id": current_example["image_id"], + "question_id": current_example["question_id"], + "keypoint": selected_keypoint_original, + } + + try: + with open(OUTPUT_FILE, "a") as f: + f.write(json.dumps(data) + "\n") + print(f"Saved: QID {data['question_id']} -> {data['keypoint']}") + demos_saved_count += 1 + update_saved_count_display() + load_next_example() + except IOError as e: + print(f"Error saving data to {OUTPUT_FILE}: {e}") + except Exception as e: + print(f"An unexpected error occurred during save: {e}") + + +def undo_callback(event): + """Clears the selected keypoint and removes the marker.""" + global selected_keypoint_original, selected_keypoint_display + remove_marker() + 
selected_keypoint_original = None + selected_keypoint_display = None + btn_save.active = False + btn_undo.active = False + plt.draw() + + +def skip_callback(event): + """Loads the next example without saving.""" + print("Skipped.") + load_next_example() + + +# --- Initialization and Execution --- +def main(): + global dataset_iterator, demos_saved_count + + print("Loading TextVQA 'train' split...") + try: + dataset = load_dataset("lmms-lab/textvqa", split="train").shuffle(seed=42) + dataset_iterator = iter(dataset) + print("Dataset loaded.") + except Exception as e: + print(f"Failed to load dataset: {e}") + return + + # Count existing demos if file exists + if os.path.exists(OUTPUT_FILE): + try: + with open(OUTPUT_FILE, "r") as f: + demos_saved_count = sum(1 for line in f) + print( + f"Found {demos_saved_count} existing demonstrations in {OUTPUT_FILE}." + ) + except Exception as e: + print( + f"Warning: Could not read existing demo count from {OUTPUT_FILE}: {e}" + ) + + # Connect callbacks + btn_save.on_clicked(save_callback) + btn_undo.on_clicked(undo_callback) + btn_skip.on_clicked(skip_callback) + fig.canvas.mpl_connect("button_press_event", on_click) + + # Load the first example and show the plot + print("Loading first example...") + update_saved_count_display() # Show initial count + load_next_example() + + if current_example: # Only show plot if first load succeeded + print("Starting labeling interface. Close the plot window to exit.") + plt.show() + else: + print("Could not load the first example. Exiting.") + + +if __name__ == "__main__": + main() diff --git a/src/r1_vlm/datasets/text_vqa/text_vqa_tool_use_r1.py b/src/r1_vlm/datasets/text_vqa/text_vqa_tool_use_r1.py index d1b2408..7bb27f3 100644 --- a/src/r1_vlm/datasets/text_vqa/text_vqa_tool_use_r1.py +++ b/src/r1_vlm/datasets/text_vqa/text_vqa_tool_use_r1.py @@ -1,4 +1,7 @@ -from datasets import Dataset, DatasetDict, load_dataset +import json + +from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset +from PIL import Image from tqdm import tqdm from r1_vlm.datasets.utils import IMAGE_PLACEHOLDER @@ -12,7 +15,8 @@ def resize_image(image): ( int(1024 * image.size[0] / longer_side), int(1024 * image.size[1] / longer_side), - ) + ), + Image.Resampling.LANCZOS, ) return image @@ -79,7 +83,10 @@ def generate_r1_messages(example): def create_r1_text_vqa_tool_use_dataset( - max_examples_per_split: int | None = None, splits_to_process: list[str] = None + max_examples_per_split: int | None = None, + splits_to_process: list[str] = None, + zoom_demos_path: str + | None = "/millcreek/home/sunil/r1_vlm_bumbershoot2/r1_vlm/src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl", ): dataset = load_dataset("lmms-lab/textvqa") @@ -91,6 +98,29 @@ def create_r1_text_vqa_tool_use_dataset( if split not in valid_splits: raise ValueError(f"Invalid split: {split}") + # --- Load Zoom Demos if path provided --- + zoom_demos_lookup = None + if zoom_demos_path and "train" in splits_to_process: + print(f"Loading zoom demonstrations from: {zoom_demos_path}") + try: + with open(zoom_demos_path, "r") as f: + demos = [json.loads(line) for line in f] + zoom_demos_lookup = { + demo["question_id"]: demo["keypoint"] for demo in demos + } + print(f"Loaded {len(zoom_demos_lookup)} zoom demonstrations.") + except FileNotFoundError: + print( + f"Warning: Zoom demos file not found at {zoom_demos_path}. Skipping augmentation." + ) + zoom_demos_lookup = None + except Exception as e: + print( + f"Warning: Error loading zoom demos from {zoom_demos_path}: {e}. 
Skipping augmentation."
+            )
+            zoom_demos_lookup = None
+    # --- End Load Zoom Demos ---
+
     processed_datasets = {}
     for split in splits_to_process:
         print(f"Processing {split} split...")
@@ -103,13 +133,79 @@
             if len(examples) >= max_examples_per_split:
                 break
 
-        processed_datasets[split] = Dataset.from_list(examples)
+        processed_dataset = Dataset.from_list(examples)
+
+        if split == "train" and zoom_demos_lookup is not None:
+            print(f"Augmenting '{split}' split with zoom keypoints...")
+
+            def add_zoom_keypoint(example):
+                keypoint = zoom_demos_lookup.get(example["question_id"], None)
+                return {"zoom_keypoint": keypoint}
+
+            # Add the zoom_keypoint column (None for questions without a demonstration)
+            processed_dataset = processed_dataset.map(add_zoom_keypoint)
+            print(f"Added 'zoom_keypoint' column to '{split}' split.")
+
+            # --- REORDERING START ---
+            print(
+                f"Reordering '{split}' split to place examples with keypoints first..."
+            )
+            # Filter examples with and without keypoints
+            with_keypoint_ds = processed_dataset.filter(
+                lambda example: example["zoom_keypoint"] is not None
+            )
+            without_keypoint_ds = processed_dataset.filter(
+                lambda example: example["zoom_keypoint"] is None
+            )
+
+            # --- Shuffle the 'without_keypoint_ds' ---
+            print("Shuffling examples without keypoints...")
+            without_keypoint_ds = without_keypoint_ds.shuffle(seed=42)
+            print("Shuffled examples without keypoints.")
+            # --- End Shuffle ---
+
+            # Concatenate them in the desired order
+            processed_dataset = concatenate_datasets(
+                [with_keypoint_ds, without_keypoint_ds]
+            )
+            print(f"Reordered '{split}' split.")
+            # --- REORDERING END ---
+
+        processed_datasets[split] = processed_dataset
 
     return DatasetDict(processed_datasets)
 
 
 if __name__ == "__main__":
     dataset = create_r1_text_vqa_tool_use_dataset(
-        max_examples_per_split=10, splits_to_process=["train"]
+        max_examples_per_split=10000, splits_to_process=["train"]
+    )
+
+    # count how many examples have zoom_keypoint
+    num_with_keypoint = len(
+        [
+            example
+            for example in dataset["train"]
+            if example["zoom_keypoint"] is not None
+        ]
     )
-    print(dataset["train"][0])
+    print(f"Number of examples with zoom_keypoint: {num_with_keypoint}")
+
+    # Verify the first few examples have keypoints (if any exist)
+    print("\nVerifying first few examples have keypoints (if available):")
+    for i in range(min(5, num_with_keypoint)):
+        print(
+            f"Example {i}: QID={dataset['train'][i]['question_id']}, Keypoint={dataset['train'][i]['zoom_keypoint']}"
+        )
+
+    # Verify an example after the keypoint block has None (if applicable)
+    if num_with_keypoint < len(dataset["train"]):
+        print("\nVerifying example after the keypoint block:")
+        idx_after = num_with_keypoint  # First example expected to have None
+        print(
+            f"Example {idx_after}: QID={dataset['train'][idx_after]['question_id']}, Keypoint={dataset['train'][idx_after]['zoom_keypoint']}"
+        )
diff --git a/src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl b/src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl
new file mode 100644
index 0000000..cb2d4d0
--- /dev/null
+++ b/src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl
@@ -0,0 +1,101 @@
+{"image_id": "0054c91397f2fe05", "question_id": 0, "keypoint": [352, 312]}
+{"image_id": "005635e119b9f32f", "question_id": 1, "keypoint": [561, 322]}
+{"image_id": "005635e119b9f32f", "question_id": 2, "keypoint": [80, 326]}
+{"image_id": "00685bc495504d61", "question_id": 3, "keypoint": [119, 1010]}
+{"image_id": "00685bc495504d61", "question_id": 4, "keypoint": 
[159, 567]} +{"image_id": "006d10667d17b924", "question_id": 5, "keypoint": [445, 461]} +{"image_id": "006d10667d17b924", "question_id": 6, "keypoint": [442, 140]} +{"image_id": "00c359f294f7dcd9", "question_id": 7, "keypoint": [373, 399]} +{"image_id": "00c359f294f7dcd9", "question_id": 8, "keypoint": [382, 400]} +{"image_id": "0122c5279a501df2", "question_id": 9, "keypoint": [907, 247]} +{"image_id": "0122c5279a501df2", "question_id": 10, "keypoint": [300, 275]} +{"image_id": "015c537f722d115e", "question_id": 11, "keypoint": [646, 362]} +{"image_id": "0202faf23a9aae11", "question_id": 12, "keypoint": [169, 601]} +{"image_id": "02b2daa20b00f5d3", "question_id": 13, "keypoint": [162, 195]} +{"image_id": "02b2daa20b00f5d3", "question_id": 14, "keypoint": [540, 322]} +{"image_id": "02da7a42e971e3cb", "question_id": 15, "keypoint": [520, 457]} +{"image_id": "02da7a42e971e3cb", "question_id": 16, "keypoint": [518, 460]} +{"image_id": "d709940b064b2ea7", "question_id": 4782, "keypoint": [525, 487]} +{"image_id": "d7e15efeee988b35", "question_id": 23953, "keypoint": [360, 631]} +{"image_id": "258dcb6be2766e6b", "question_id": 10485, "keypoint": [54, 296]} +{"image_id": "ae0e286ea6301919", "question_id": 6787, "keypoint": [854, 483]} +{"image_id": "a9085b21fafeb7e0", "question_id": 12304, "keypoint": [357, 519]} +{"image_id": "0876802af9ca4f53", "question_id": 2069, "keypoint": [295, 537]} +{"image_id": "51153c1aeeba2142", "question_id": 21859, "keypoint": [568, 533]} +{"image_id": "bb4f9f6fa8be3784", "question_id": 23499, "keypoint": [573, 303]} +{"image_id": "0cc86f2bc0023406", "question_id": 18672, "keypoint": [290, 534]} +{"image_id": "e47a56d412750c3e", "question_id": 14447, "keypoint": [383, 360]} +{"image_id": "3109e0ab52305483", "question_id": 5725, "keypoint": [336, 263]} +{"image_id": "65277ae6a3eb6c8a", "question_id": 14311, "keypoint": [478, 644]} +{"image_id": "3678e10d1ee15709", "question_id": 2707, "keypoint": [800, 190]} +{"image_id": "04f9cc3db7d8da24", "question_id": 1303, "keypoint": [646, 456]} +{"image_id": "99c2dd0295e97a04", "question_id": 30730, "keypoint": [92, 417]} +{"image_id": "fcab06f91d7da18d", "question_id": 33237, "keypoint": [414, 591]} +{"image_id": "61a5eb30c602a678", "question_id": 3308, "keypoint": [461, 508]} +{"image_id": "3bc293efa0203314", "question_id": 18512, "keypoint": [370, 471]} +{"image_id": "4a8830e36c78f5e7", "question_id": 14061, "keypoint": [553, 267]} +{"image_id": "beed9557550278a3", "question_id": 9481, "keypoint": [545, 327]} +{"image_id": "4c858f86738a72fd", "question_id": 1534, "keypoint": [666, 348]} +{"image_id": "104bf2461f9962a8", "question_id": 5445, "keypoint": [452, 532]} +{"image_id": "5b41dbe96cd05104", "question_id": 26950, "keypoint": [689, 363]} +{"image_id": "8dad2e89e4ad9839", "question_id": 6536, "keypoint": [58, 189]} +{"image_id": "ad899b711c4e5613", "question_id": 6784, "keypoint": [119, 406]} +{"image_id": "0927271c7f9f3b9e", "question_id": 30035, "keypoint": [153, 324]} +{"image_id": "cbe00b9c28494af5", "question_id": 17311, "keypoint": [444, 538]} +{"image_id": "2d4f3317cfe522c8", "question_id": 21210, "keypoint": [438, 598]} +{"image_id": "8568bc5fc7abbb33", "question_id": 16729, "keypoint": [611, 114]} +{"image_id": "a70b30b1806e1226", "question_id": 23157, "keypoint": [732, 293]} +{"image_id": "008b8556a73601ae", "question_id": 17883, "keypoint": [384, 830]} +{"image_id": "1b8bf6437a221a10", "question_id": 1412, "keypoint": [461, 222]} +{"image_id": "4f8d13bce21b6ada", "question_id": 34364, "keypoint": [397, 359]} 
+{"image_id": "1a47acd8e0d07441", "question_id": 28380, "keypoint": [540, 258]} +{"image_id": "143d9adc0368b1dd", "question_id": 17992, "keypoint": [508, 249]} +{"image_id": "bba917b91c4c46dd", "question_id": 12361, "keypoint": [961, 402]} +{"image_id": "0204574848b0bf05", "question_id": 29549, "keypoint": [363, 521]} +{"image_id": "083ea2d88ee71e24", "question_id": 31564, "keypoint": [211, 763]} +{"image_id": "2f6621dfaa4e4672", "question_id": 27459, "keypoint": [745, 452]} +{"image_id": "d929fd569dbae5fe", "question_id": 4809, "keypoint": [516, 882]} +{"image_id": "20bb636326491022", "question_id": 20971, "keypoint": [503, 63]} +{"image_id": "8df181c0899df120", "question_id": 3843, "keypoint": [446, 204]} +{"image_id": "923545f7c0785426", "question_id": 34026, "keypoint": [505, 864]} +{"image_id": "89f64c747cf016ec", "question_id": 29772, "keypoint": [531, 543]} +{"image_id": "ba0c5d07cd0803fa", "question_id": 25480, "keypoint": [495, 92]} +{"image_id": "6ce0bbbee320e928", "question_id": 12134, "keypoint": [134, 686]} +{"image_id": "a760819a4cd02bc5", "question_id": 32792, "keypoint": [479, 661]} +{"image_id": "04b99754c51e6092", "question_id": 1296, "keypoint": [219, 377]} +{"image_id": "061d8e57615d98c4", "question_id": 12598, "keypoint": [140, 536]} +{"image_id": "f3d9c7c15556813a", "question_id": 24298, "keypoint": [890, 89]} +{"image_id": "023463e93c2fa36e", "question_id": 24935, "keypoint": [631, 39]} +{"image_id": "6b1bad9e14c54eff", "question_id": 13167, "keypoint": [242, 498]} +{"image_id": "14297a09cd0704df", "question_id": 2226, "keypoint": [444, 480]} +{"image_id": "9f9e1b7cdf8cc9d5", "question_id": 25407, "keypoint": [401, 49]} +{"image_id": "0012836b7cb7e1a8", "question_id": 29531, "keypoint": [822, 462]} +{"image_id": "2e975f0231b49855", "question_id": 28647, "keypoint": [522, 286]} +{"image_id": "ac32d325dbe4cf76", "question_id": 17033, "keypoint": [674, 28]} +{"image_id": "a6628a4b35d3382b", "question_id": 6722, "keypoint": [571, 426]} +{"image_id": "8c48baf8904eeca4", "question_id": 18305, "keypoint": [413, 333]} +{"image_id": "2124bb3f0620f8b0", "question_id": 11024, "keypoint": [578, 460]} +{"image_id": "5091e1f3b5a27a0f", "question_id": 14142, "keypoint": [392, 342]} +{"image_id": "eacf4d74fad54bac", "question_id": 31014, "keypoint": [883, 279]} +{"image_id": "3c570608f258bcfd", "question_id": 21503, "keypoint": [394, 793]} +{"image_id": "cb2aa50349915d64", "question_id": 23771, "keypoint": [325, 147]} +{"image_id": "92d76e2fdaf6f07d", "question_id": 3918, "keypoint": [639, 914]} +{"image_id": "06bc833b6794cb6d", "question_id": 18174, "keypoint": [399, 407]} +{"image_id": "0fc054810d8ab304", "question_id": 8424, "keypoint": [274, 37]} +{"image_id": "423165577f3f52ce", "question_id": 30439, "keypoint": [575, 348]} +{"image_id": "12bc08d2fce21a45", "question_id": 10451, "keypoint": [919, 286]} +{"image_id": "e101f9601e0d971c", "question_id": 667, "keypoint": [89, 416]} +{"image_id": "5894cb4708630e4c", "question_id": 27527, "keypoint": [202, 624]} +{"image_id": "df824b1152781395", "question_id": 20093, "keypoint": [862, 995]} +{"image_id": "fecd1e75b8deeca0", "question_id": 17756, "keypoint": [182, 183]} +{"image_id": "d81bf036569df61d", "question_id": 8999, "keypoint": [743, 125]} +{"image_id": "813f257ac2f2bfe2", "question_id": 25334, "keypoint": [141, 724]} +{"image_id": "dc9d9de141534563", "question_id": 27148, "keypoint": [494, 372]} +{"image_id": "56853543d7ae1235", "question_id": 13047, "keypoint": [362, 77]} +{"image_id": "0919842b32156e39", "question_id": 
9127, "keypoint": [650, 264]} +{"image_id": "7112b28a96993e28", "question_id": 9994, "keypoint": [218, 677]} +{"image_id": "5ddd2a8a3c2328d6", "question_id": 19626, "keypoint": [693, 457]} +{"image_id": "765bb8ab1e6a0eff", "question_id": 15176, "keypoint": [536, 93]} +{"image_id": "4b7044e659d47df3", "question_id": 31928, "keypoint": [322, 107]} +{"image_id": "476198cff16798d6", "question_id": 25176, "keypoint": [460, 420]} +{"image_id": "e0b856b9fe929346", "question_id": 4906, "keypoint": [433, 561]} diff --git a/src/r1_vlm/environments/multistep_vision_env.py b/src/r1_vlm/environments/multistep_vision_env.py index 825cf7a..9cc6a72 100644 --- a/src/r1_vlm/environments/multistep_vision_env.py +++ b/src/r1_vlm/environments/multistep_vision_env.py @@ -218,7 +218,9 @@ def clean_messages_for_logging(messages): ) for image in images: try: - imgcat.imgcat(image) + image_copy = image.copy() + image_copy.convert("RGB") + imgcat.imgcat(image_copy) except Exception as e: print( f"Caught failed imgcat call for image. As this is just a debugging print, we will not raise an error. {e}" diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py index 344947f..d041d94 100644 --- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py +++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py @@ -1,3 +1,4 @@ +import math import re from typing import Any, Callable @@ -292,8 +293,17 @@ def get_dataset( output_datasets[split] = dataset_split + # only keep the examples with keypoints from the train set for initial finetuning + # if "train" in splits: + # print("Filtering train set to only include examples with zoom keypoints.") + # output_datasets["train"] = output_datasets["train"].filter( + # lambda x: x["zoom_keypoint"] is not None + # ) + # print(f"After filtering, {len(output_datasets['train'])} examples remain.") + + # shuffle the train set if "train" in splits: - output_datasets["train"].shuffle() + output_datasets["train"] = output_datasets["train"].shuffle() return output_datasets @@ -341,6 +351,11 @@ def get_reward_weights(self) -> list[float]: # consistent high reward for getting the answer right schedule = 1.0 reward_weights.append(schedule) + + elif reward_function.__name__ == "zoom_keypoint_reward_func": + # consistent high reward for using the zoom keypoint + schedule = 1.0 + reward_weights.append(schedule) else: raise ValueError( f"Unknown reward function: {reward_function.__name__} encountered in get_reward_weights" @@ -548,10 +563,104 @@ def correct_answer_reward_func( return correctness_rewards + def extract_keypoint(text): + # Pattern to match [x, y] where x and y are integers, allowing multiline + pattern = r"\[\s*(\d+)\s*,\s*(\d+)\s*\]" + match = re.search(pattern, text, re.DOTALL) + + if match: + x = int(match.group(1)) + y = int(match.group(2)) + return x, y + return None + + def zoom_keypoint_reward_func( + prompts, completions, completions_messages, **kwargs + ): + merged_completion_conversations = MultistepVisionEnv.preprocess_messages( + prompts_messages=prompts, completions_messages=completions_messages + ) + + # extract the image size from the prompt + prompt = prompts[0] + image_content = prompt[1]["content"][1] + image = image_content["image"] + image_width = image.width + image_height = image.height + + # extract the zoom keypoint from the completion, if it exists + conv_keypoints = [] + for conv in merged_completion_conversations: + 
keypoint_found = False
+                for msg in conv:
+                    if msg["role"] == "assistant":
+                        parsed = self.parser.parse(msg["content"][0]["text"])
+                        if hasattr(parsed, "tool") and parsed.tool is not None:
+                            tool_content = parsed.tool
+                            if "name: zoom" in tool_content:
+                                conv_keypoint = extract_keypoint(tool_content)
+                                conv_keypoints.append(conv_keypoint)
+                                keypoint_found = True
+                                break
+                # if no tool call with a zoom keypoint is found, append None
+                if not keypoint_found:
+                    conv_keypoints.append(None)
+
+            print(f"conv_keypoints: {conv_keypoints}")
+
+            # check if the example has a zoom keypoint - kwargs carries one entry per generation in the group, and every entry is identical
+            zoom_keypoints = kwargs.get("zoom_keypoint", None)
+
+            # if this example does not have a zoom keypoint, return 0 reward for all completions as there's no way to verify the result
+            if zoom_keypoints is None:
+                return [0.0] * len(merged_completion_conversations)
+
+            zoom_keypoint = zoom_keypoints[0]
+
+            # if the zoom keypoint is not present, return 0 reward for all completions as there's no way to verify the result
+            if zoom_keypoint is None:
+                return [0.0] * len(merged_completion_conversations)
+
+            normalized_zoom_keypoint = (
+                zoom_keypoint[0] / image_width,
+                zoom_keypoint[1] / image_height,
+            )
+            print(f"normalized_zoom_keypoint: {normalized_zoom_keypoint}")
+
+            normalized_conv_keypoints = []
+            for conv_keypoint in conv_keypoints:
+                if conv_keypoint is None:
+                    normalized_conv_keypoints.append(None)
+                else:
+                    normalized_conv_keypoints.append(
+                        (
+                            conv_keypoint[0] / image_width,
+                            conv_keypoint[1] / image_height,
+                        )
+                    )
+
+            rewards = []
+            for normalized_conv_keypoint in normalized_conv_keypoints:
+                if normalized_conv_keypoint is None:
+                    rewards.append(0.0)
+                else:
+                    # compute the normalized euclidean distance between the keypoint in the tool call and the demonstrated zoom keypoint
+                    distance = math.sqrt(
+                        (normalized_conv_keypoint[0] - normalized_zoom_keypoint[0]) ** 2
+                        + (normalized_conv_keypoint[1] - normalized_zoom_keypoint[1])
+                        ** 2
+                    )
+                    # clip the reward at 0.0, as 1 - distance can be as low as 1 - sqrt(2) in the worst case
+                    reward = max(1 - distance, 0.0)
+                    rewards.append(reward)
+
+            return rewards
+
         return [
             format_reward_func,
             tool_execution_reward_func,
             correct_answer_reward_func,
+            zoom_keypoint_reward_func,
         ]
 
     def log_metrics(self, conversations, completions_text, completion_messages):
@@ -560,10 +669,16 @@ def log_metrics(self, conversations, completions_text, completion_messages):
 
         completions_with_tool_use = 0
         completions_with_zoom_use = 0
+        use_horizontal_zoom = 0
+        use_vertical_zoom = 0
+        use_square_zoom = 0
 
         for completion in completions_text:
             tool_use_regex = r"<tool>(.*?)</tool>"
             zoom_use_string = "name: zoom"
+            horizontal_zoom_use_string = "aspect_ratio_mode: horizontal"
+            vertical_zoom_use_string = "aspect_ratio_mode: vertical"
+            square_zoom_use_string = "aspect_ratio_mode: square"
             tool_matches = re.findall(tool_use_regex, completion, re.DOTALL)
             if tool_matches:
                 completions_with_tool_use += 1
@@ -571,8 +686,15 @@
                 if zoom_use_string in tool_content:
                     completions_with_zoom_use += 1
 
+                if horizontal_zoom_use_string in tool_content:
+                    use_horizontal_zoom += 1
+                if vertical_zoom_use_string in tool_content:
+                    use_vertical_zoom += 1
+                if square_zoom_use_string in tool_content:
+                    use_square_zoom += 1
+
         print(
-            f"There are {len(completions_text)} completions, {completions_with_tool_use} of which attempt to use a tool, {completions_with_zoom_use} of which attempt to use zoom"
+            f"There are 
{len(completions_text)} completions: \n {completions_with_tool_use} attempt to use a tool, \n {completions_with_zoom_use} attempt to use zoom. \n Of the zoom calls, {use_horizontal_zoom} use horizontal, \n {use_vertical_zoom} use vertical, \n and {use_square_zoom} use square aspect ratio"
         )
 
         num_completions = len(completions_text)
@@ -582,6 +704,9 @@
         return {
             "tool_use_proportion": tool_use_proportion,
             "zoom_use_proportion": zoom_use_proportion,
+            "use_horizontal_zoom": use_horizontal_zoom,
+            "use_vertical_zoom": use_vertical_zoom,
+            "use_square_zoom": use_square_zoom,
         }
diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py
index acd678d..51bb215 100644
--- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py
+++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py
@@ -90,8 +90,12 @@ def find_target_linear_names(
 
 
 def train():
+    checkpoint = "/millcreek/home/sunil/r1_vlm/vlm-r1-text-vqa-distance-reward-zoom-aspect-ratio-may4-3B/checkpoint-500"
+
     model, peft_config, processor, model_config, gradient_checkpointing = (
-        load_model_and_processor(gradient_checkpointing=True, use_peft=False)
+        load_model_and_processor(
+            model_name_or_path=checkpoint, gradient_checkpointing=True, use_peft=False
+        )
     )
     print("loaded model")
 
@@ -110,7 +114,7 @@ def train():
     training_args = GRPOConfig(
         model_init_kwargs=model_config,
         # save path on the runpod instance
-        output_dir="vlm-r1-text-vqa-tool-use-no-tool-reward-may4-3B",
+        output_dir="vlm-r1-text-vqa-post-demo-ft-may5-3B",
         # increase learning rate for PEFT - 1e-4
         learning_rate=1e-4 if peft_config is not None else 1e-6,
         max_grad_norm=1.0,
@@ -120,7 +124,7 @@ def train():
         logging_steps=1,
         save_steps=50,
         save_total_limit=10,
-        num_train_epochs=1,
+        num_train_epochs=100,
         per_device_train_batch_size=1,
         num_generations=3,
         gradient_accumulation_steps=4,
@@ -158,6 +162,8 @@ def train():
         env=vf_env,
         peft_config=peft_config,
         guided_regex=None,
+        # shuffling is disabled because we want to start with the keypoint examples.
+        shuffle_dataset=False,
     )
 
     trainer.train()
diff --git a/src/r1_vlm/tools/zoom.py b/src/r1_vlm/tools/zoom.py
index 534481d..74809dd 100644
--- a/src/r1_vlm/tools/zoom.py
+++ b/src/r1_vlm/tools/zoom.py
@@ -1,5 +1,6 @@
 # generic zoom tool
 import json
+import math
 from typing import List, Tuple
 
 import pytest
@@ -9,31 +10,33 @@
 
 
 def calculate_crop_coordinates(
-    keypoint: List[int], image_size: Tuple[int, int], target_size: int = 300
+    keypoint: List[int],
+    image_size: Tuple[int, int],
+    aspect_ratio_mode: str,
+    base_target_size: int = 300,
 ) -> Tuple[int, int, int, int]:
     """
-    Calculates the crop coordinates for a square box centered around a keypoint.
+    Calculates the crop coordinates for a box centered around a keypoint,
+    with a specified aspect ratio.
 
+    The box aims for an area roughly equal to base_target_size * base_target_size.
     If the target box extends beyond the image boundaries, the box is shifted
-    to stay within the image, maintaining the target size. The keypoint will
-    no longer be centered in this case.
-
-    The returned coordinates are suitable for use with PIL's Image.crop method,
-    following the (left, upper, right, lower) format where 'right' and 'lower'
-    are exclusive. 
+ to stay within the image, maintaining the target size and aspect ratio. + The keypoint will no longer be centered in this case. Args: keypoint: A list or tuple [x, y] representing the desired center coordinates. image_size: A tuple (width, height) of the original image. - target_size: The desired width and height of the square crop box. + aspect_ratio_mode: The desired aspect ratio ("square", "horizontal", "vertical"). + base_target_size: The base dimension used for area calculation. Returns: A tuple (crop_left, crop_upper, crop_right, crop_lower) defining the region to crop from the original image. Raises: - ValueError: If the keypoint is outside the image boundaries or inputs - are invalid types/formats. + ValueError: If inputs are invalid types/formats, keypoint is outside + the image boundaries, or aspect_ratio_mode is invalid. """ # --- Input Validation --- if ( @@ -54,8 +57,13 @@ def calculate_crop_coordinates( "Error:image_size must be a tuple of two positive integers (width, height)" ) - if not isinstance(target_size, int) or target_size <= 0: - raise ValueError("Error: target_size must be a positive integer") + if not isinstance(base_target_size, int) or base_target_size <= 0: + raise ValueError("Error: base_target_size must be a positive integer") + + if aspect_ratio_mode not in ["square", "horizontal", "vertical"]: + raise ValueError( + "Error: aspect_ratio_mode must be 'square', 'horizontal', or 'vertical'" + ) x, y = keypoint width, height = image_size @@ -66,22 +74,40 @@ def calculate_crop_coordinates( f"(width={width}, height={height})" ) + # --- Calculate Target Crop Dimensions based on Aspect Ratio --- + sqrt2 = math.sqrt(2) + if aspect_ratio_mode == "square": + target_crop_width = base_target_size + target_crop_height = base_target_size + elif aspect_ratio_mode == "horizontal": # 2:1 + target_crop_width = round(base_target_size * sqrt2) + target_crop_height = round(base_target_size / sqrt2) + else: # aspect_ratio_mode == "vertical" (1:2) + target_crop_width = round(base_target_size / sqrt2) + target_crop_height = round(base_target_size * sqrt2) + + # Ensure dimensions are at least 1 + target_crop_width = max(1, target_crop_width) + target_crop_height = max(1, target_crop_height) + # --- Calculate Crop Box --- - half_size = target_size // 2 + # Calculate half-sizes (integer division) + half_width = target_crop_width // 2 + half_height = target_crop_height // 2 # Calculate the ideal top-left corner to center the box - ideal_x_min = x - half_size - ideal_y_min = y - half_size + ideal_x_min = x - half_width + ideal_y_min = y - half_height # Initialize the actual top-left corner crop_left = ideal_x_min crop_upper = ideal_y_min # Adjust if the box extends beyond the right or bottom boundaries - if crop_left + target_size > width: - crop_left = width - target_size - if crop_upper + target_size > height: - crop_upper = height - target_size + if crop_left + target_crop_width > width: + crop_left = width - target_crop_width + if crop_upper + target_crop_height > height: + crop_upper = height - target_crop_height # Adjust if the box extends beyond the left or top boundaries if crop_left < 0: @@ -91,13 +117,11 @@ def calculate_crop_coordinates( # Calculate the final right and lower bounds for PIL crop (exclusive) # Clamp the right/lower bounds to the image dimensions, important if - # target_size > image width/height. - crop_right = min(crop_left + target_size, width) - crop_lower = min(crop_upper + target_size, height) + # target crop dimensions > image width/height. 
+    crop_right = min(crop_left + target_crop_width, width)
+    crop_lower = min(crop_upper + target_crop_height, height)
 
     # Ensure left/upper are also clamped in case target_size > image dimensions
-    # which might lead to negative values after boundary adjustments.
-    # The previous checks already handle this, but being explicit doesn't hurt.
     crop_left = max(0, crop_left)
     crop_upper = max(0, crop_upper)
 
@@ -107,29 +131,48 @@
 def zoom(
     image_name: str,
     keypoint: list[int],
+    aspect_ratio_mode: str,
     **kwargs,
 ) -> Image.Image:
     """
-    Returns an image zoomed in on the specified keypoint.
+    Returns an image zoomed in on the specified keypoint with a given aspect ratio.
 
     This is useful to see a region of interest with higher clarity, like for reading text or identifying objects.
 
+    The choice of aspect ratio depends on the task.
+    If you are looking at horizontal text or some other wide region, you might use "horizontal" aspect ratio.
+    If you are looking at vertical text or some other tall region, you might use "vertical" aspect ratio.
+    If you are looking at a region more generally, you might use "square" aspect ratio.
+
    Args:
         image_name: str, the name of the image to zoom in on. Can only be called on the "input_image".
         keypoint: list[int], the [x, y] coordinates of the point to center the zoom on. Must be within the image boundaries.
+        aspect_ratio_mode: str, the desired aspect ratio for the zoom window. Must be one of "square", "horizontal", "vertical".
+            "square" is 1:1, "horizontal" is 2:1 (wider), "vertical" is 1:2 (taller).
 
     Returns:
-        The image zoomed in on the specified keypoint.
+        The image zoomed in on the specified keypoint with the requested aspect ratio.
 
     Examples:
+        <tool>
         name: zoom
         image_name: input_image
+        aspect_ratio_mode: square
         keypoint: [500, 400]
+        </tool>
 
+        <tool>
+        name: zoom
+        image_name: input_image
+        aspect_ratio_mode: horizontal
+        keypoint: [23, 44]
+        </tool>
+
+        <tool>
         name: zoom
         image_name: input_image
-        keypoint: [50, 40]
+        aspect_ratio_mode: vertical
+        keypoint: [723, 461]
+        </tool>
     """
 
     # get and validate the image
@@ -146,27 +189,50 @@
             f"Error: Image {image_name} is not the input_image. This tool can only be called on the input_image."
) - width, height = image.size - # size we will crop to - target_size = 300 - - # we will resize the image to the larger dimension of the input image, or 400 if it's a smaller image - resize_size = max(width, height, 400) + original_width, original_height = image.size + # base size for cropping area calculation + base_crop_target_size = 300 - # Validate keypoint and calculate crop box using the helper function - # ValueError will be raised by the helper if keypoint is invalid - crop_box = calculate_crop_coordinates(keypoint, (width, height), target_size) + # Validate aspect_ratio_mode and keypoint, then calculate crop box + crop_box = calculate_crop_coordinates( + keypoint, + (original_width, original_height), + aspect_ratio_mode, + base_target_size=base_crop_target_size, + ) # crop the image to the calculated box cropped_image = image.crop(crop_box) - - # Resize the cropped image to the target size - output_image = cropped_image.resize( - (resize_size, resize_size), Image.Resampling.LANCZOS + crop_w, crop_h = cropped_image.size + + # If crop resulted in zero size (e.g., invalid crop box calculation edge case) + if crop_w == 0 or crop_h == 0: + raise ValueError("Error: Cropped image has zero dimension") + + # Determine the base dimension for resizing the output - use original image dims or 400 min + base_resize_dim = max(original_width, original_height, 400) + + # Calculate the final output size, preserving aspect ratio of the cropped image + # Scale so the LARGER dimension of the output matches base_resize_dim + if crop_w >= crop_h: # Wider or square crop + output_w = base_resize_dim + output_h = round(output_w * crop_h / crop_w) # Maintain aspect ratio + else: # Taller crop + output_h = base_resize_dim + output_w = round(output_h * crop_w / crop_h) # Maintain aspect ratio + + # Ensure minimum dimension is 1 + output_w = max(1, output_w) + output_h = max(1, output_h) + output_size = (output_w, output_h) + + # Resize the cropped image to the calculated output size, preserving aspect ratio + output_image = cropped_image.resize(output_size, Image.Resampling.LANCZOS) + + print( + f"Zoom tool cropped to {crop_w}x{crop_h}, resized output to: {output_image.size}" ) - print(f"Zoom tool output image size: {output_image.size}") - return output_image @@ -174,7 +240,7 @@ def parse_zoom_args(raw_args: RawToolArgs) -> TypedToolArgs: """ Parses raw string arguments for the zoom tool (keypoint version). - Expects keys: 'name', 'image_name', 'keypoint'. + Expects keys: 'name', 'image_name', 'keypoint', 'aspect_ratio_mode'. Converts 'keypoint' from a JSON string to a list of integers. Detailed validation of values (e.g., keypoint coordinates) is deferred to the zoom function itself. @@ -184,14 +250,20 @@ def parse_zoom_args(raw_args: RawToolArgs) -> TypedToolArgs: Returns: A dictionary containing the arguments with basic type conversions applied, - ready for the zoom function. Keys: 'image_name', 'keypoint'. + ready for the zoom function. Keys: 'image_name', 'keypoint', 'aspect_ratio_mode'. Raises: ValueError: If required keys are missing, extra keys are present, - or basic type conversion fails (e.g., 'keypoint' is not valid JSON - or doesn't result in a list of two integers). + basic type conversion fails (e.g., 'keypoint' is not valid JSON + or doesn't result in a list of two integers), or + 'aspect_ratio_mode' has an invalid value. 
""" - required_keys = {"name", "image_name", "keypoint"} + required_keys = { + "name", + "image_name", + "keypoint", + "aspect_ratio_mode", + } # Added aspect_ratio_mode actual_keys = set(raw_args.keys()) # 1. Check for Missing Keys @@ -224,9 +296,12 @@ def parse_zoom_args(raw_args: RawToolArgs) -> TypedToolArgs: raise ValueError( "Error: Both elements in 'keypoint' list must be integers." ) - typed_args["keypoint"] = keypoint_list + # Validate aspect_ratio_mode + aspect_mode = raw_args["aspect_ratio_mode"] + typed_args["aspect_ratio_mode"] = aspect_mode + except json.JSONDecodeError: raise ValueError( f"Error: Invalid JSON format for 'keypoint': '{raw_args['keypoint']}'" @@ -254,7 +329,7 @@ def test_calc_centered(sample_image): img_size = sample_image["input_image"].size kp = [500, 400] ts = 300 - box = calculate_crop_coordinates(kp, img_size, ts) + box = calculate_crop_coordinates(kp, img_size, "square", ts) assert box == (350, 250, 650, 550) @@ -262,7 +337,7 @@ def test_calc_top_left(sample_image): img_size = sample_image["input_image"].size kp = [50, 40] ts = 300 - box = calculate_crop_coordinates(kp, img_size, ts) + box = calculate_crop_coordinates(kp, img_size, "square", ts) assert box == (0, 0, 300, 300) @@ -270,7 +345,7 @@ def test_calc_bottom_right(sample_image): img_size = sample_image["input_image"].size kp = [950, 750] ts = 300 - box = calculate_crop_coordinates(kp, img_size, ts) + box = calculate_crop_coordinates(kp, img_size, "square", ts) assert box == (700, 500, 1000, 800) @@ -278,7 +353,7 @@ def test_calc_right_middle(sample_image): img_size = sample_image["input_image"].size kp = [950, 400] ts = 300 - box = calculate_crop_coordinates(kp, img_size, ts) + box = calculate_crop_coordinates(kp, img_size, "square", ts) assert box == (700, 250, 1000, 550) @@ -286,7 +361,7 @@ def test_calc_small_image(sample_image): small_img_size = (200, 150) kp = [100, 75] ts = 300 - box = calculate_crop_coordinates(kp, small_img_size, ts) + box = calculate_crop_coordinates(kp, small_img_size, "square", ts) assert box == (0, 0, 200, 150) @@ -294,30 +369,30 @@ def test_calc_invalid_keypoint_coords(sample_image): img_size = sample_image["input_image"].size ts = 300 with pytest.raises(ValueError, match="outside the image boundaries"): - calculate_crop_coordinates([1000, 400], img_size, ts) + calculate_crop_coordinates([1000, 400], img_size, "square", ts) with pytest.raises(ValueError, match="outside the image boundaries"): - calculate_crop_coordinates([-1, 400], img_size, ts) + calculate_crop_coordinates([-1, 400], img_size, "square", ts) with pytest.raises(ValueError, match="outside the image boundaries"): - calculate_crop_coordinates([500, 800], img_size, ts) + calculate_crop_coordinates([500, 800], img_size, "square", ts) with pytest.raises(ValueError, match="outside the image boundaries"): - calculate_crop_coordinates([500, -1], img_size, ts) + calculate_crop_coordinates([500, -1], img_size, "square", ts) def test_calc_invalid_input_types(): with pytest.raises(ValueError, match="keypoint must be a list or tuple"): - calculate_crop_coordinates("[100, 100]", (500, 500), 300) + calculate_crop_coordinates("[100, 100]", (500, 500), "square", 300) with pytest.raises(ValueError, match="keypoint must be a list or tuple"): - calculate_crop_coordinates([100], (500, 500), 300) + calculate_crop_coordinates([100], (500, 500), "square", 300) with pytest.raises(ValueError, match="keypoint must be a list or tuple"): - calculate_crop_coordinates([100.0, 100], (500, 500), 300) + 
calculate_crop_coordinates([100.0, 100], (500, 500), "square", 300) with pytest.raises(ValueError, match="image_size must be a tuple"): - calculate_crop_coordinates([100, 100], [500, 500], 300) + calculate_crop_coordinates([100, 100], [500, 500], "square", 300) with pytest.raises(ValueError, match="image_size must be a tuple"): - calculate_crop_coordinates([100, 100], (500, 0), 300) + calculate_crop_coordinates([100, 100], (500, 0), "square", 300) with pytest.raises(ValueError, match="target_size must be a positive integer"): - calculate_crop_coordinates([100, 100], (500, 500), 0) + calculate_crop_coordinates([100, 100], (500, 500), "square", 0) with pytest.raises(ValueError, match="target_size must be a positive integer"): - calculate_crop_coordinates([100, 100], (500, 500), -100) + calculate_crop_coordinates([100, 100], (500, 500), "square", -100) # --- Tests for zoom function --- @@ -325,7 +400,7 @@ def test_calc_invalid_input_types(): def test_zoom_invalid_image_name(sample_image): with pytest.raises(ValueError, match="not found"): - zoom("nonexistent_image", [100, 100], images=sample_image) + zoom("nonexistent_image", [100, 100], "square", images=sample_image) # Add test for incorrect image name usage (e.g., "test_image" instead of "input_image") @@ -334,6 +409,7 @@ def test_zoom_incorrect_image_name_usage(sample_image): zoom( "test_image", [100, 100], + "square", images={"test_image": sample_image["input_image"]}, ) @@ -342,44 +418,50 @@ def test_zoom_incorrect_image_name_usage(sample_image): def test_zoom_invalid_keypoint_format(sample_image): with pytest.raises(ValueError, match="keypoint must be a list or tuple"): zoom( - "input_image", "[100, 100]", images=sample_image + "input_image", "[100, 100]", "square", images=sample_image ) # Pass string instead of list with pytest.raises(ValueError, match="keypoint must be a list or tuple"): - zoom("input_image", [100], images=sample_image) # Pass list with 1 element + zoom( + "input_image", [100], "square", images=sample_image + ) # Pass list with 1 element with pytest.raises(ValueError, match="keypoint must be a list or tuple"): - zoom("input_image", [100.0, 100], images=sample_image) # Pass float + zoom("input_image", [100.0, 100], "square", images=sample_image) # Pass float # Test keypoint outside image bounds (should be caught by calculate_crop_coordinates) def test_zoom_keypoint_out_of_bounds(sample_image): img_size = sample_image["input_image"].size with pytest.raises(ValueError, match="outside the image boundaries"): - zoom("input_image", [img_size[0], 100], images=sample_image) # x = width + zoom( + "input_image", [img_size[0], 100], "square", images=sample_image + ) # x = width with pytest.raises(ValueError, match="outside the image boundaries"): - zoom("input_image", [-1, 100], images=sample_image) # x < 0 + zoom("input_image", [-1, 100], "square", images=sample_image) # x < 0 with pytest.raises(ValueError, match="outside the image boundaries"): - zoom("input_image", [100, img_size[1]], images=sample_image) # y = height + zoom( + "input_image", [100, img_size[1]], "square", images=sample_image + ) # y = height with pytest.raises(ValueError, match="outside the image boundaries"): - zoom("input_image", [100, -1], images=sample_image) # y < 0 + zoom("input_image", [100, -1], "square", images=sample_image) # y < 0 def test_zoom_basic_centered(sample_image): keypoint = [500, 400] # Center of 1000x800 image - result = zoom("input_image", keypoint, images=sample_image) + result = zoom("input_image", keypoint, "square", 
images=sample_image) assert isinstance(result, Image.Image) assert result.size == (300, 300) # Always 300x300 output def test_zoom_edge_case_top_left(sample_image): keypoint = [50, 40] # Near top-left - result = zoom("input_image", keypoint, images=sample_image) + result = zoom("input_image", keypoint, "square", images=sample_image) assert isinstance(result, Image.Image) assert result.size == (300, 300) def test_zoom_edge_case_bottom_right(sample_image): keypoint = [950, 750] # Near bottom-right - result = zoom("input_image", keypoint, images=sample_image) + result = zoom("input_image", keypoint, "square", images=sample_image) assert isinstance(result, Image.Image) assert result.size == (300, 300) @@ -389,7 +471,7 @@ def test_zoom_small_image(sample_image): small_img = Image.new("RGB", (200, 150)) small_sample = {"input_image": small_img} keypoint = [100, 75] # Center of small image - result = zoom("input_image", keypoint, images=small_sample) + result = zoom("input_image", keypoint, "square", images=small_sample) assert isinstance(result, Image.Image) assert result.size == (300, 300) # Output is still 300x300 after resize @@ -404,7 +486,7 @@ def test_zoom_content_check(sample_image): filled_sample = {"input_image": img} keypoint = [500, 400] # Center point within the red box - result = zoom("input_image", keypoint, images=filled_sample) + result = zoom("input_image", keypoint, "square", images=filled_sample) assert result.size == (300, 300) # Check a central pixel in the output (should correspond to the red box) # The original keypoint [500, 400] should map to the center [150, 150] in the 300x300 output @@ -416,14 +498,25 @@ def test_zoom_content_check(sample_image): def test_parse_valid_args(): - raw = {"name": "zoom", "image_name": "input_image", "keypoint": "[150, 250]"} - expected = {"image_name": "input_image", "keypoint": [150, 250]} + raw = { + "name": "zoom", + "image_name": "input_image", + "keypoint": "[150, 250]", + "aspect_ratio_mode": "square", + } + expected = { + "image_name": "input_image", + "keypoint": [150, 250], + "aspect_ratio_mode": "square", + } assert parse_zoom_args(raw) == expected def test_parse_missing_key(): - raw = {"name": "zoom", "image_name": "input_image"} - with pytest.raises(ValueError, match="Missing required arguments.*keypoint"): + raw = {"name": "zoom", "image_name": "input_image", "keypoint": "[150, 250]"} + with pytest.raises( + ValueError, match="Missing required arguments.*aspect_ratio_mode" + ): parse_zoom_args(raw) @@ -432,6 +525,7 @@ def test_parse_extra_key(): "name": "zoom", "image_name": "input_image", "keypoint": "[100, 100]", + "aspect_ratio_mode": "square", "extra": "bad", } with pytest.raises(ValueError, match="Unexpected arguments.*extra"): @@ -443,6 +537,7 @@ def test_parse_invalid_json(): "name": "zoom", "image_name": "input_image", "keypoint": "[100, 100", + "aspect_ratio_mode": "square", } # Missing closing bracket with pytest.raises(ValueError, match="Invalid JSON format for 'keypoint'"): parse_zoom_args(raw) @@ -453,6 +548,7 @@ def test_parse_keypoint_not_list(): "name": "zoom", "image_name": "input_image", "keypoint": '{"x": 100, "y": 100}', + "aspect_ratio_mode": "square", } with pytest.raises(ValueError, match="'keypoint' must be a JSON list"): parse_zoom_args(raw) @@ -503,18 +599,18 @@ def test_parse_keypoint_wrong_type(): try: print(f"\nTesting center keypoint: {kp_center}") - result_center = zoom("input_image", kp_center, images=images_dict) + result_center = zoom("input_image", kp_center, "square", images=images_dict) 
print(f"Center zoom output size: {result_center.size}") # result_center.show() # Optionally display the image print(f"\nTesting edge keypoint: {kp_edge}") - result_edge = zoom("input_image", kp_edge, images=images_dict) + result_edge = zoom("input_image", kp_edge, "square", images=images_dict) print(f"Edge zoom output size: {result_edge.size}") # result_edge.show() # Optionally display the image print("\nTesting invalid keypoint:") try: - zoom("input_image", [1200, 100], images=images_dict) + zoom("input_image", [1200, 100], "square", images=images_dict) except ValueError as e: print(f"Caught expected error: {e}") @@ -523,6 +619,7 @@ def test_parse_keypoint_wrong_type(): "name": "zoom", "image_name": "input_image", "keypoint": "[150, 250]", + "aspect_ratio_mode": "square", } parsed = parse_zoom_args(raw_valid) print(f"Parsed valid args: {parsed}") @@ -531,6 +628,7 @@ def test_parse_keypoint_wrong_type(): "name": "zoom", "image_name": "input_image", "keypoint": "[150.5, 250]", + "aspect_ratio_mode": "square", } try: parse_zoom_args(raw_invalid)