diff --git a/src/r1_vlm/datasets/text_vqa/collect_demos.py b/src/r1_vlm/datasets/text_vqa/collect_demos.py
new file mode 100644
index 0000000..23cad67
--- /dev/null
+++ b/src/r1_vlm/datasets/text_vqa/collect_demos.py
@@ -0,0 +1,284 @@
+import json
+import os
+
+import matplotlib.pyplot as plt
+import matplotlib.widgets as widgets
+from datasets import load_dataset
+from PIL import Image
+
+OUTPUT_FILE = "zoom_demos.jsonl"
+
+
+# --- Image Resizing Logic (copied from text_vqa_tool_use_r1.py) ---
+def resize_image(image):
+ """
+ Resizes the image if its longer side exceeds 1024 pixels,
+ maintaining aspect ratio.
+ """
+ # if the longer side is greater than 1024, resize it so the longer side is 1024
+ if image.mode == "L": # Handle grayscale images
+ image = image.convert("RGB")
+
+ if image.size[0] > 1024 or image.size[1] > 1024:
+ longer_side = max(image.size[0], image.size[1])
+ image = image.resize(
+ (
+ int(1024 * image.size[0] / longer_side),
+ int(1024 * image.size[1] / longer_side),
+ ),
+ Image.Resampling.LANCZOS, # Use LANCZOS for better quality
+ )
+ return image
+
+
+# --- Global State (simpler for this script) ---
+current_example = None
+original_image = None
+original_size = None
+displayed_image = None
+displayed_size = None
+selected_keypoint_original = None
+selected_keypoint_display = None
+marker_plot = None
+dataset_iterator = None
+demos_saved_count = 0
+
+# --- Matplotlib UI Elements ---
+fig, ax = plt.subplots(figsize=(10, 8)) # Make figure a bit larger
+plt.subplots_adjust(bottom=0.25) # More space for controls and title
+ax.axis("off")
+
+ax_save = plt.axes([0.81, 0.05, 0.15, 0.075])
+btn_save = widgets.Button(ax_save, "Save & Next")
+btn_save.active = False # Initially disabled
+
+ax_undo = plt.axes([0.65, 0.05, 0.15, 0.075])
+btn_undo = widgets.Button(ax_undo, "Undo Click")
+btn_undo.active = False # Initially disabled
+
+ax_skip = plt.axes([0.49, 0.05, 0.15, 0.075])
+btn_skip = widgets.Button(ax_skip, "Skip")
+
+# Text to show saved count
+ax_count_text = plt.axes([0.05, 0.05, 0.4, 0.075])
+ax_count_text.axis("off")
+count_text_obj = ax_count_text.text(
+ 0.5,
+ 0.5,
+ f"Saved: {demos_saved_count}",
+ ha="center",
+ va="center",
+ transform=ax_count_text.transAxes,
+)
+
+
+# --- Core Logic Functions ---
+def update_saved_count_display():
+ """Updates the saved count text on the plot."""
+ global demos_saved_count
+ count_text_obj.set_text(f"Saved: {demos_saved_count}")
+ plt.draw()
+
+
+def remove_marker():
+ """Removes the keypoint marker from the plot if it exists."""
+ global marker_plot
+ if marker_plot:
+ try:
+ marker_plot.remove()
+ except ValueError: # May have already been removed if plot cleared
+ pass
+ marker_plot = None
+
+
+def display_current_example():
+ """Displays the current image and question."""
+ global displayed_image, displayed_size, marker_plot
+ if current_example is None:
+ return
+
+ remove_marker() # Clear marker from previous example
+
+ # Store original and prepare display image
+ original_image_pil = current_example["image"]
+ if original_image_pil.mode == "L": # Ensure RGB
+ original_image_pil = original_image_pil.convert("RGB")
+
+ displayed_image = resize_image(original_image_pil.copy()) # Use resized for display
+ displayed_size = displayed_image.size
+
+ ax.cla() # Clear axes before drawing new image
+ ax.imshow(displayed_image)
+ answers = current_example.get("answers", []) # Get answers, default to empty list
+ answers_str = ", ".join(answers) if answers else "N/A"
+ question_text = (
+ f"Q ({current_example['question_id']}): {current_example['question']}\n"
+ f"Answers: {answers_str}"
+ )
+ ax.set_title(
+ question_text, wrap=True, fontsize=9
+ ) # Slightly smaller font for more text
+ ax.axis("off")
+ plt.draw()
+
+
+def load_next_example():
+ """Loads the next example from the dataset iterator and updates the UI."""
+ global current_example, original_image, original_size
+ global selected_keypoint_original, selected_keypoint_display, marker_plot
+ global dataset_iterator
+
+ if dataset_iterator is None:
+ print("Error: Dataset iterator not initialized.")
+ plt.close(fig)
+ return
+
+ try:
+ current_example = next(dataset_iterator)
+ original_image = current_example["image"]
+ if original_image.mode == "L": # Ensure RGB for original too
+ original_image = original_image.convert("RGB")
+ original_size = original_image.size
+
+ # Reset state for the new example
+ selected_keypoint_original = None
+ selected_keypoint_display = None
+ remove_marker() # Explicitly remove marker here too
+
+ display_current_example()
+
+ btn_save.active = False
+ btn_undo.active = False
+ plt.draw() # Update button states visually
+
+ except StopIteration:
+ print("Finished processing all examples in the dataset split.")
+ plt.close(fig)
+ except Exception as e:
+ print(f"An unexpected error occurred loading the next example: {e}")
+ plt.close(fig)
+
+
+def on_click(event):
+ """Handles mouse clicks on the image axes to select a keypoint."""
+ global selected_keypoint_original, selected_keypoint_display, marker_plot
+ global original_size, displayed_size
+
+ if event.inaxes != ax or original_size is None or displayed_size is None:
+ return # Click outside image axes or data not loaded
+
+ # Get click coordinates relative to displayed image
+ x_disp, y_disp = int(event.xdata), int(event.ydata)
+
+ # Map coordinates back to original image dimensions
+ scale_w = original_size[0] / displayed_size[0]
+ scale_h = original_size[1] / displayed_size[1]
+ x_orig = int(round(x_disp * scale_w))
+ y_orig = int(round(y_disp * scale_h))
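+    # e.g. a 2048x1536 original displayed at 1024x768 gives scale 2.0, so a click at
+    # display coords (512, 384) maps back to (1024, 768) in the original image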
+
+ # Clamp coordinates to ensure they are within the original image bounds
+ x_orig = max(0, min(original_size[0] - 1, x_orig))
+ y_orig = max(0, min(original_size[1] - 1, y_orig))
+
+ selected_keypoint_original = [x_orig, y_orig]
+ selected_keypoint_display = [x_disp, y_disp]
+
+ # Update marker
+ remove_marker()
+ (marker_plot,) = ax.plot(x_disp, y_disp, "r+", markersize=12, markeredgewidth=2)
+
+ # Enable Save and Undo buttons
+ btn_save.active = True
+ btn_undo.active = True
+ plt.draw()
+
+
+def save_callback(event):
+ """Saves the current example's info and keypoint, then loads the next."""
+ global demos_saved_count
+ if current_example is None or selected_keypoint_original is None:
+ print("Warning: Cannot save - no example loaded or keypoint selected.")
+ return
+
+ data = {
+ "image_id": current_example["image_id"],
+ "question_id": current_example["question_id"],
+ "keypoint": selected_keypoint_original,
+ }
+
+ try:
+ with open(OUTPUT_FILE, "a") as f:
+ f.write(json.dumps(data) + "\n")
+ print(f"Saved: QID {data['question_id']} -> {data['keypoint']}")
+ demos_saved_count += 1
+ update_saved_count_display()
+ load_next_example()
+ except IOError as e:
+ print(f"Error saving data to {OUTPUT_FILE}: {e}")
+ except Exception as e:
+ print(f"An unexpected error occurred during save: {e}")
+
+
+def undo_callback(event):
+ """Clears the selected keypoint and removes the marker."""
+ global selected_keypoint_original, selected_keypoint_display
+ remove_marker()
+ selected_keypoint_original = None
+ selected_keypoint_display = None
+ btn_save.active = False
+ btn_undo.active = False
+ plt.draw()
+
+
+def skip_callback(event):
+ """Loads the next example without saving."""
+ print("Skipped.")
+ load_next_example()
+
+
+# --- Initialization and Execution ---
+def main():
+ global dataset_iterator, demos_saved_count
+
+ print("Loading TextVQA 'train' split...")
+ try:
+ dataset = load_dataset("lmms-lab/textvqa", split="train").shuffle(seed=42)
+ dataset_iterator = iter(dataset)
+ print("Dataset loaded.")
+ except Exception as e:
+ print(f"Failed to load dataset: {e}")
+ return
+
+ # Count existing demos if file exists
+ if os.path.exists(OUTPUT_FILE):
+ try:
+ with open(OUTPUT_FILE, "r") as f:
+ demos_saved_count = sum(1 for line in f)
+ print(
+ f"Found {demos_saved_count} existing demonstrations in {OUTPUT_FILE}."
+ )
+ except Exception as e:
+ print(
+ f"Warning: Could not read existing demo count from {OUTPUT_FILE}: {e}"
+ )
+
+ # Connect callbacks
+ btn_save.on_clicked(save_callback)
+ btn_undo.on_clicked(undo_callback)
+ btn_skip.on_clicked(skip_callback)
+ fig.canvas.mpl_connect("button_press_event", on_click)
+
+ # Load the first example and show the plot
+ print("Loading first example...")
+ update_saved_count_display() # Show initial count
+ load_next_example()
+
+ if current_example: # Only show plot if first load succeeded
+ print("Starting labeling interface. Close the plot window to exit.")
+ plt.show()
+ else:
+ print("Could not load the first example. Exiting.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/r1_vlm/datasets/text_vqa/text_vqa_tool_use_r1.py b/src/r1_vlm/datasets/text_vqa/text_vqa_tool_use_r1.py
index d1b2408..7bb27f3 100644
--- a/src/r1_vlm/datasets/text_vqa/text_vqa_tool_use_r1.py
+++ b/src/r1_vlm/datasets/text_vqa/text_vqa_tool_use_r1.py
@@ -1,4 +1,7 @@
-from datasets import Dataset, DatasetDict, load_dataset
+import json
+
+from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
+from PIL import Image
from tqdm import tqdm
from r1_vlm.datasets.utils import IMAGE_PLACEHOLDER
@@ -12,7 +15,8 @@ def resize_image(image):
(
int(1024 * image.size[0] / longer_side),
int(1024 * image.size[1] / longer_side),
- )
+ ),
+ Image.Resampling.LANCZOS,
)
return image
@@ -79,7 +83,10 @@ def generate_r1_messages(example):
def create_r1_text_vqa_tool_use_dataset(
- max_examples_per_split: int | None = None, splits_to_process: list[str] = None
+ max_examples_per_split: int | None = None,
+ splits_to_process: list[str] = None,
+ zoom_demos_path: str
+ | None = "/millcreek/home/sunil/r1_vlm_bumbershoot2/r1_vlm/src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl",
):
dataset = load_dataset("lmms-lab/textvqa")
@@ -91,6 +98,29 @@ def create_r1_text_vqa_tool_use_dataset(
if split not in valid_splits:
raise ValueError(f"Invalid split: {split}")
+ # --- Load Zoom Demos if path provided ---
+ zoom_demos_lookup = None
+ if zoom_demos_path and "train" in splits_to_process:
+ print(f"Loading zoom demonstrations from: {zoom_demos_path}")
+ try:
+ with open(zoom_demos_path, "r") as f:
+ demos = [json.loads(line) for line in f]
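+                # each line is a record like {"image_id": "0054c91397f2fe05", "question_id": 0, "keypoint": [352, 312]}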
+ zoom_demos_lookup = {
+ demo["question_id"]: demo["keypoint"] for demo in demos
+ }
+ print(f"Loaded {len(zoom_demos_lookup)} zoom demonstrations.")
+ except FileNotFoundError:
+ print(
+ f"Warning: Zoom demos file not found at {zoom_demos_path}. Skipping augmentation."
+ )
+ zoom_demos_lookup = None
+ except Exception as e:
+ print(
+ f"Warning: Error loading zoom demos from {zoom_demos_path}: {e}. Skipping augmentation."
+ )
+ zoom_demos_lookup = None
+ # --- End Load Zoom Demos ---
+
processed_datasets = {}
for split in splits_to_process:
print(f"Processing {split} split...")
@@ -103,13 +133,79 @@ def create_r1_text_vqa_tool_use_dataset(
if len(examples) >= max_examples_per_split:
break
- processed_datasets[split] = Dataset.from_list(examples)
+ processed_dataset = Dataset.from_list(examples)
+
+ if split == "train" and zoom_demos_lookup is not None:
+ print(f"Augmenting '{split}' split with zoom keypoints...")
+
+ def add_zoom_keypoint(example):
+ keypoint = zoom_demos_lookup.get(example["question_id"], None)
+ return {"zoom_keypoint": keypoint}
+
+ # Add the column
+            processed_dataset = processed_dataset.map(add_zoom_keypoint)
+ print(f"Added 'zoom_keypoint' column to '{split}' split.")
+
+ # --- REORDERING START ---
+ print(
+ f"Reordering '{split}' split to place examples with keypoints first..."
+ )
+ # Filter examples with and without keypoints
+ with_keypoint_ds = processed_dataset.filter(
+ lambda example: example["zoom_keypoint"] is not None
+ )
+ without_keypoint_ds = processed_dataset.filter(
+ lambda example: example["zoom_keypoint"] is None
+ )
+
+ # --- Shuffle the 'without_keypoint_ds' ---
+ print("Shuffling examples without keypoints...")
+            without_keypoint_ds = without_keypoint_ds.shuffle(seed=42)
+ print("Shuffled examples without keypoints.")
+ # --- End Shuffle ---
+
+ # Concatenate them in the desired order
+ processed_dataset = concatenate_datasets(
+ [with_keypoint_ds, without_keypoint_ds]
+ )
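+            # keypoint-labeled examples come first so that, with trainer shuffling
+            # disabled, the earliest training steps see verifiable zoom demonstrations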
+ print(f"Reordered '{split}' split.")
+ # --- REORDERING END ---
+
+ processed_datasets[split] = processed_dataset
return DatasetDict(processed_datasets)
if __name__ == "__main__":
dataset = create_r1_text_vqa_tool_use_dataset(
- max_examples_per_split=10, splits_to_process=["train"]
+ max_examples_per_split=10000, splits_to_process=["train"]
+ )
+
+ # count how many examples have zoom_keypoint
+ num_with_keypoint = len(
+ [
+ example
+ for example in dataset["train"]
+ if example["zoom_keypoint"] is not None
+ ]
)
- print(dataset["train"][0])
+ print(f"Number of examples with zoom_keypoint: {num_with_keypoint}")
+
+ # Verify the first few examples have keypoints (if any exist)
+ print("\nVerifying first few examples have keypoints (if available):")
+ for i in range(min(5, num_with_keypoint)):
+ print(
+ f"Example {i}: QID={dataset['train'][i]['question_id']}, Keypoint={dataset['train'][i]['zoom_keypoint']}"
+ )
+
+ # Verify an example after the keypoint block has None (if applicable)
+ if num_with_keypoint < len(dataset["train"]):
+ print("\nVerifying example after the keypoint block:")
+ idx_after = num_with_keypoint # First example expected to have None
+ print(
+ f"Example {idx_after}: QID={dataset['train'][idx_after]['question_id']}, Keypoint={dataset['train'][idx_after]['zoom_keypoint']}"
+ )
diff --git a/src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl b/src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl
new file mode 100644
index 0000000..cb2d4d0
--- /dev/null
+++ b/src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl
@@ -0,0 +1,101 @@
+{"image_id": "0054c91397f2fe05", "question_id": 0, "keypoint": [352, 312]}
+{"image_id": "005635e119b9f32f", "question_id": 1, "keypoint": [561, 322]}
+{"image_id": "005635e119b9f32f", "question_id": 2, "keypoint": [80, 326]}
+{"image_id": "00685bc495504d61", "question_id": 3, "keypoint": [119, 1010]}
+{"image_id": "00685bc495504d61", "question_id": 4, "keypoint": [159, 567]}
+{"image_id": "006d10667d17b924", "question_id": 5, "keypoint": [445, 461]}
+{"image_id": "006d10667d17b924", "question_id": 6, "keypoint": [442, 140]}
+{"image_id": "00c359f294f7dcd9", "question_id": 7, "keypoint": [373, 399]}
+{"image_id": "00c359f294f7dcd9", "question_id": 8, "keypoint": [382, 400]}
+{"image_id": "0122c5279a501df2", "question_id": 9, "keypoint": [907, 247]}
+{"image_id": "0122c5279a501df2", "question_id": 10, "keypoint": [300, 275]}
+{"image_id": "015c537f722d115e", "question_id": 11, "keypoint": [646, 362]}
+{"image_id": "0202faf23a9aae11", "question_id": 12, "keypoint": [169, 601]}
+{"image_id": "02b2daa20b00f5d3", "question_id": 13, "keypoint": [162, 195]}
+{"image_id": "02b2daa20b00f5d3", "question_id": 14, "keypoint": [540, 322]}
+{"image_id": "02da7a42e971e3cb", "question_id": 15, "keypoint": [520, 457]}
+{"image_id": "02da7a42e971e3cb", "question_id": 16, "keypoint": [518, 460]}
+{"image_id": "d709940b064b2ea7", "question_id": 4782, "keypoint": [525, 487]}
+{"image_id": "d7e15efeee988b35", "question_id": 23953, "keypoint": [360, 631]}
+{"image_id": "258dcb6be2766e6b", "question_id": 10485, "keypoint": [54, 296]}
+{"image_id": "ae0e286ea6301919", "question_id": 6787, "keypoint": [854, 483]}
+{"image_id": "a9085b21fafeb7e0", "question_id": 12304, "keypoint": [357, 519]}
+{"image_id": "0876802af9ca4f53", "question_id": 2069, "keypoint": [295, 537]}
+{"image_id": "51153c1aeeba2142", "question_id": 21859, "keypoint": [568, 533]}
+{"image_id": "bb4f9f6fa8be3784", "question_id": 23499, "keypoint": [573, 303]}
+{"image_id": "0cc86f2bc0023406", "question_id": 18672, "keypoint": [290, 534]}
+{"image_id": "e47a56d412750c3e", "question_id": 14447, "keypoint": [383, 360]}
+{"image_id": "3109e0ab52305483", "question_id": 5725, "keypoint": [336, 263]}
+{"image_id": "65277ae6a3eb6c8a", "question_id": 14311, "keypoint": [478, 644]}
+{"image_id": "3678e10d1ee15709", "question_id": 2707, "keypoint": [800, 190]}
+{"image_id": "04f9cc3db7d8da24", "question_id": 1303, "keypoint": [646, 456]}
+{"image_id": "99c2dd0295e97a04", "question_id": 30730, "keypoint": [92, 417]}
+{"image_id": "fcab06f91d7da18d", "question_id": 33237, "keypoint": [414, 591]}
+{"image_id": "61a5eb30c602a678", "question_id": 3308, "keypoint": [461, 508]}
+{"image_id": "3bc293efa0203314", "question_id": 18512, "keypoint": [370, 471]}
+{"image_id": "4a8830e36c78f5e7", "question_id": 14061, "keypoint": [553, 267]}
+{"image_id": "beed9557550278a3", "question_id": 9481, "keypoint": [545, 327]}
+{"image_id": "4c858f86738a72fd", "question_id": 1534, "keypoint": [666, 348]}
+{"image_id": "104bf2461f9962a8", "question_id": 5445, "keypoint": [452, 532]}
+{"image_id": "5b41dbe96cd05104", "question_id": 26950, "keypoint": [689, 363]}
+{"image_id": "8dad2e89e4ad9839", "question_id": 6536, "keypoint": [58, 189]}
+{"image_id": "ad899b711c4e5613", "question_id": 6784, "keypoint": [119, 406]}
+{"image_id": "0927271c7f9f3b9e", "question_id": 30035, "keypoint": [153, 324]}
+{"image_id": "cbe00b9c28494af5", "question_id": 17311, "keypoint": [444, 538]}
+{"image_id": "2d4f3317cfe522c8", "question_id": 21210, "keypoint": [438, 598]}
+{"image_id": "8568bc5fc7abbb33", "question_id": 16729, "keypoint": [611, 114]}
+{"image_id": "a70b30b1806e1226", "question_id": 23157, "keypoint": [732, 293]}
+{"image_id": "008b8556a73601ae", "question_id": 17883, "keypoint": [384, 830]}
+{"image_id": "1b8bf6437a221a10", "question_id": 1412, "keypoint": [461, 222]}
+{"image_id": "4f8d13bce21b6ada", "question_id": 34364, "keypoint": [397, 359]}
+{"image_id": "1a47acd8e0d07441", "question_id": 28380, "keypoint": [540, 258]}
+{"image_id": "143d9adc0368b1dd", "question_id": 17992, "keypoint": [508, 249]}
+{"image_id": "bba917b91c4c46dd", "question_id": 12361, "keypoint": [961, 402]}
+{"image_id": "0204574848b0bf05", "question_id": 29549, "keypoint": [363, 521]}
+{"image_id": "083ea2d88ee71e24", "question_id": 31564, "keypoint": [211, 763]}
+{"image_id": "2f6621dfaa4e4672", "question_id": 27459, "keypoint": [745, 452]}
+{"image_id": "d929fd569dbae5fe", "question_id": 4809, "keypoint": [516, 882]}
+{"image_id": "20bb636326491022", "question_id": 20971, "keypoint": [503, 63]}
+{"image_id": "8df181c0899df120", "question_id": 3843, "keypoint": [446, 204]}
+{"image_id": "923545f7c0785426", "question_id": 34026, "keypoint": [505, 864]}
+{"image_id": "89f64c747cf016ec", "question_id": 29772, "keypoint": [531, 543]}
+{"image_id": "ba0c5d07cd0803fa", "question_id": 25480, "keypoint": [495, 92]}
+{"image_id": "6ce0bbbee320e928", "question_id": 12134, "keypoint": [134, 686]}
+{"image_id": "a760819a4cd02bc5", "question_id": 32792, "keypoint": [479, 661]}
+{"image_id": "04b99754c51e6092", "question_id": 1296, "keypoint": [219, 377]}
+{"image_id": "061d8e57615d98c4", "question_id": 12598, "keypoint": [140, 536]}
+{"image_id": "f3d9c7c15556813a", "question_id": 24298, "keypoint": [890, 89]}
+{"image_id": "023463e93c2fa36e", "question_id": 24935, "keypoint": [631, 39]}
+{"image_id": "6b1bad9e14c54eff", "question_id": 13167, "keypoint": [242, 498]}
+{"image_id": "14297a09cd0704df", "question_id": 2226, "keypoint": [444, 480]}
+{"image_id": "9f9e1b7cdf8cc9d5", "question_id": 25407, "keypoint": [401, 49]}
+{"image_id": "0012836b7cb7e1a8", "question_id": 29531, "keypoint": [822, 462]}
+{"image_id": "2e975f0231b49855", "question_id": 28647, "keypoint": [522, 286]}
+{"image_id": "ac32d325dbe4cf76", "question_id": 17033, "keypoint": [674, 28]}
+{"image_id": "a6628a4b35d3382b", "question_id": 6722, "keypoint": [571, 426]}
+{"image_id": "8c48baf8904eeca4", "question_id": 18305, "keypoint": [413, 333]}
+{"image_id": "2124bb3f0620f8b0", "question_id": 11024, "keypoint": [578, 460]}
+{"image_id": "5091e1f3b5a27a0f", "question_id": 14142, "keypoint": [392, 342]}
+{"image_id": "eacf4d74fad54bac", "question_id": 31014, "keypoint": [883, 279]}
+{"image_id": "3c570608f258bcfd", "question_id": 21503, "keypoint": [394, 793]}
+{"image_id": "cb2aa50349915d64", "question_id": 23771, "keypoint": [325, 147]}
+{"image_id": "92d76e2fdaf6f07d", "question_id": 3918, "keypoint": [639, 914]}
+{"image_id": "06bc833b6794cb6d", "question_id": 18174, "keypoint": [399, 407]}
+{"image_id": "0fc054810d8ab304", "question_id": 8424, "keypoint": [274, 37]}
+{"image_id": "423165577f3f52ce", "question_id": 30439, "keypoint": [575, 348]}
+{"image_id": "12bc08d2fce21a45", "question_id": 10451, "keypoint": [919, 286]}
+{"image_id": "e101f9601e0d971c", "question_id": 667, "keypoint": [89, 416]}
+{"image_id": "5894cb4708630e4c", "question_id": 27527, "keypoint": [202, 624]}
+{"image_id": "df824b1152781395", "question_id": 20093, "keypoint": [862, 995]}
+{"image_id": "fecd1e75b8deeca0", "question_id": 17756, "keypoint": [182, 183]}
+{"image_id": "d81bf036569df61d", "question_id": 8999, "keypoint": [743, 125]}
+{"image_id": "813f257ac2f2bfe2", "question_id": 25334, "keypoint": [141, 724]}
+{"image_id": "dc9d9de141534563", "question_id": 27148, "keypoint": [494, 372]}
+{"image_id": "56853543d7ae1235", "question_id": 13047, "keypoint": [362, 77]}
+{"image_id": "0919842b32156e39", "question_id": 9127, "keypoint": [650, 264]}
+{"image_id": "7112b28a96993e28", "question_id": 9994, "keypoint": [218, 677]}
+{"image_id": "5ddd2a8a3c2328d6", "question_id": 19626, "keypoint": [693, 457]}
+{"image_id": "765bb8ab1e6a0eff", "question_id": 15176, "keypoint": [536, 93]}
+{"image_id": "4b7044e659d47df3", "question_id": 31928, "keypoint": [322, 107]}
+{"image_id": "476198cff16798d6", "question_id": 25176, "keypoint": [460, 420]}
+{"image_id": "e0b856b9fe929346", "question_id": 4906, "keypoint": [433, 561]}
diff --git a/src/r1_vlm/environments/multistep_vision_env.py b/src/r1_vlm/environments/multistep_vision_env.py
index 825cf7a..9cc6a72 100644
--- a/src/r1_vlm/environments/multistep_vision_env.py
+++ b/src/r1_vlm/environments/multistep_vision_env.py
@@ -218,7 +218,9 @@ def clean_messages_for_logging(messages):
)
for image in images:
try:
- imgcat.imgcat(image)
+ image_copy = image.copy()
+            image_copy = image_copy.convert("RGB")
+ imgcat.imgcat(image_copy)
except Exception as e:
print(
f"Caught failed imgcat call for image. As this is just a debugging print, we will not raise an error. {e}"
diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py
index 344947f..d041d94 100644
--- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py
+++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py
@@ -1,3 +1,4 @@
+import math
import re
from typing import Any, Callable
@@ -292,8 +293,17 @@ def get_dataset(
output_datasets[split] = dataset_split
+ # only keep the examples with keypoints from the train set for initial finetuning
+ # if "train" in splits:
+ # print("Filtering train set to only include examples with zoom keypoints.")
+ # output_datasets["train"] = output_datasets["train"].filter(
+ # lambda x: x["zoom_keypoint"] is not None
+ # )
+ # print(f"After filtering, {len(output_datasets['train'])} examples remain.")
+
+ # shuffle the train set
if "train" in splits:
- output_datasets["train"].shuffle()
+ output_datasets["train"] = output_datasets["train"].shuffle()
return output_datasets
@@ -341,6 +351,11 @@ def get_reward_weights(self) -> list[float]:
# consistent high reward for getting the answer right
schedule = 1.0
reward_weights.append(schedule)
+
+ elif reward_function.__name__ == "zoom_keypoint_reward_func":
+ # consistent high reward for using the zoom keypoint
+ schedule = 1.0
+ reward_weights.append(schedule)
else:
raise ValueError(
f"Unknown reward function: {reward_function.__name__} encountered in get_reward_weights"
@@ -548,10 +563,104 @@ def correct_answer_reward_func(
return correctness_rewards
+ def extract_keypoint(text):
+ # Pattern to match [x, y] where x and y are integers, allowing multiline
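+        # e.g. "keypoint: [500, 400]" yields (500, 400); None is returned when no pair is found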
+ pattern = r"\[\s*(\d+)\s*,\s*(\d+)\s*\]"
+ match = re.search(pattern, text, re.DOTALL)
+
+ if match:
+ x = int(match.group(1))
+ y = int(match.group(2))
+ return x, y
+ return None
+
+ def zoom_keypoint_reward_func(
+ prompts, completions, completions_messages, **kwargs
+ ):
+ merged_completion_conversations = MultistepVisionEnv.preprocess_messages(
+ prompts_messages=prompts, completions_messages=completions_messages
+ )
+
+ # extract the image size from the prompt
+ prompt = prompts[0]
+ image_content = prompt[1]["content"][1]
+ image = image_content["image"]
+ image_width = image.width
+ image_height = image.height
+
+ # extract the zoom keypoint from the completion, if it exists
+ conv_keypoints = []
+ for conv in merged_completion_conversations:
+ keypoint_found = False
+ for msg in conv:
+ if msg["role"] == "assistant":
+ parsed = self.parser.parse(msg["content"][0]["text"])
+ if hasattr(parsed, "tool") and parsed.tool is not None:
+ tool_content = parsed.tool
+ if "name: zoom" in tool_content:
+ conv_keypoint = extract_keypoint(tool_content)
+ conv_keypoints.append(conv_keypoint)
+ keypoint_found = True
+ break
+ # if no tool call with a zoom keypoint is found, append None
+ if not keypoint_found:
+ conv_keypoints.append(None)
+
+ print(f"conv_keypoints: {conv_keypoints}")
+
+        # check if the example has a zoom keypoint - kwargs provides a group-sized list in which every entry is identical
+ zoom_keypoints = kwargs.get("zoom_keypoint", None)
+
+ # if this example does not have a zoom keypoint, return 0 reward for all completions as there's no way to verify the result
+ if zoom_keypoints is None:
+ return [0.0] * len(merged_completion_conversations)
+
+ zoom_keypoint = zoom_keypoints[0]
+
+ # if the zoom keypoint is not present, return 0 reward for all completions as there's no way to verify the result
+ if zoom_keypoint is None:
+ return [0.0] * len(merged_completion_conversations)
+
+ normalized_zoom_keypoint = (
+ zoom_keypoint[0] / image_width,
+ zoom_keypoint[1] / image_height,
+ )
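+        # normalizing by the image dimensions makes the distance (and therefore the
+        # reward below) independent of image resolution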
+ print(f"normalized_zoom_keypoint: {normalized_zoom_keypoint}")
+
+ normalized_conv_keypoints = []
+ for conv_keypoint in conv_keypoints:
+ if conv_keypoint is None:
+ normalized_conv_keypoints.append(None)
+ else:
+ normalized_conv_keypoints.append(
+ (
+ conv_keypoint[0] / image_width,
+ conv_keypoint[1] / image_height,
+ )
+ )
+
+ rewards = []
+ for normalized_conv_keypoint in normalized_conv_keypoints:
+ if normalized_conv_keypoint is None:
+ rewards.append(0.0)
+ else:
+ # compute the normalized euclidean distance between the keypoint in the tool call and the zoom keypoint
+ distance = math.sqrt(
+ (normalized_conv_keypoint[0] - normalized_zoom_keypoint[0]) ** 2
+ + (normalized_conv_keypoint[1] - normalized_zoom_keypoint[1])
+ ** 2
+ )
+                    # clip the reward at 0.0, since 1 - distance can be as low as 1 - sqrt(2)
+ reward = max(1 - distance, 0.0)
+ rewards.append(reward)
+
+ return rewards
+
return [
format_reward_func,
tool_execution_reward_func,
correct_answer_reward_func,
+ zoom_keypoint_reward_func,
]
def log_metrics(self, conversations, completions_text, completion_messages):
@@ -560,10 +669,16 @@ def log_metrics(self, conversations, completions_text, completion_messages):
completions_with_tool_use = 0
completions_with_zoom_use = 0
+ use_horizontal_zoom = 0
+ use_vertical_zoom = 0
+ use_square_zoom = 0
for completion in completions_text:
             tool_use_regex = r"<tool>(.*?)</tool>"
zoom_use_string = "name: zoom"
+ horizontal_zoom_use_string = "aspect_ratio_mode: horizontal"
+ vertical_zoom_use_string = "aspect_ratio_mode: vertical"
+ square_zoom_use_string = "aspect_ratio_mode: square"
tool_matches = re.findall(tool_use_regex, completion, re.DOTALL)
if tool_matches:
completions_with_tool_use += 1
@@ -571,8 +686,15 @@ def log_metrics(self, conversations, completions_text, completion_messages):
if zoom_use_string in tool_content:
completions_with_zoom_use += 1
+ if horizontal_zoom_use_string in tool_content:
+ use_horizontal_zoom += 1
+ if vertical_zoom_use_string in tool_content:
+ use_vertical_zoom += 1
+ if square_zoom_use_string in tool_content:
+ use_square_zoom += 1
+
print(
- f"There are {len(completions_text)} completions, {completions_with_tool_use} of which attempt to use a tool, {completions_with_zoom_use} of which attempt to use zoom"
+            f"There are {len(completions_text)} completions: \n {completions_with_tool_use} attempt to use a tool, \n {completions_with_zoom_use} attempt to use zoom, \n {use_horizontal_zoom} use horizontal zoom, \n {use_vertical_zoom} use vertical zoom, \n {use_square_zoom} use square zoom"
)
num_completions = len(completions_text)
@@ -582,6 +704,9 @@ def log_metrics(self, conversations, completions_text, completion_messages):
return {
"tool_use_proportion": tool_use_proportion,
"zoom_use_proportion": zoom_use_proportion,
+ "use_horizontal_zoom": use_horizontal_zoom,
+ "use_vertical_zoom": use_vertical_zoom,
+ "use_square_zoom": use_square_zoom,
}
diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py
index acd678d..51bb215 100644
--- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py
+++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py
@@ -90,8 +90,12 @@ def find_target_linear_names(
def train():
+ checkpoint = "/millcreek/home/sunil/r1_vlm/vlm-r1-text-vqa-distance-reward-zoom-aspect-ratio-may4-3B/checkpoint-500"
+
model, peft_config, processor, model_config, gradient_checkpointing = (
- load_model_and_processor(gradient_checkpointing=True, use_peft=False)
+ load_model_and_processor(
+ model_name_or_path=checkpoint, gradient_checkpointing=True, use_peft=False
+ )
)
print("loaded model")
@@ -110,7 +114,7 @@ def train():
training_args = GRPOConfig(
model_init_kwargs=model_config,
# save path on the runpod instance
- output_dir="vlm-r1-text-vqa-tool-use-no-tool-reward-may4-3B",
+ output_dir="vlm-r1-text-vqa-post-demo-ft-may5-3B",
# increase learning rate for PEFT - 1e-4
learning_rate=1e-4 if peft_config is not None else 1e-6,
max_grad_norm=1.0,
@@ -120,7 +124,7 @@ def train():
logging_steps=1,
save_steps=50,
save_total_limit=10,
- num_train_epochs=1,
+ num_train_epochs=100,
per_device_train_batch_size=1,
num_generations=3,
gradient_accumulation_steps=4,
@@ -158,6 +162,8 @@ def train():
env=vf_env,
peft_config=peft_config,
guided_regex=None,
+ # shuffling is disabled because we want to start with the keypoint examples.
+ shuffle_dataset=False,
)
trainer.train()
diff --git a/src/r1_vlm/tools/zoom.py b/src/r1_vlm/tools/zoom.py
index 534481d..74809dd 100644
--- a/src/r1_vlm/tools/zoom.py
+++ b/src/r1_vlm/tools/zoom.py
@@ -1,5 +1,6 @@
# generic zoom tool
import json
+import math # Added import
from typing import List, Tuple
import pytest
@@ -9,31 +10,33 @@
def calculate_crop_coordinates(
- keypoint: List[int], image_size: Tuple[int, int], target_size: int = 300
+ keypoint: List[int],
+ image_size: Tuple[int, int],
+ aspect_ratio_mode: str,
+ base_target_size: int = 300,
) -> Tuple[int, int, int, int]:
"""
- Calculates the crop coordinates for a square box centered around a keypoint.
+ Calculates the crop coordinates for a box centered around a keypoint,
+ with a specified aspect ratio.
+ The box aims for an area roughly equal to base_target_size * base_target_size.
If the target box extends beyond the image boundaries, the box is shifted
- to stay within the image, maintaining the target size. The keypoint will
- no longer be centered in this case.
-
- The returned coordinates are suitable for use with PIL's Image.crop method,
- following the (left, upper, right, lower) format where 'right' and 'lower'
- are exclusive.
+ to stay within the image, maintaining the target size and aspect ratio.
+ The keypoint will no longer be centered in this case.
Args:
keypoint: A list or tuple [x, y] representing the desired center coordinates.
image_size: A tuple (width, height) of the original image.
- target_size: The desired width and height of the square crop box.
+ aspect_ratio_mode: The desired aspect ratio ("square", "horizontal", "vertical").
+ base_target_size: The base dimension used for area calculation.
Returns:
A tuple (crop_left, crop_upper, crop_right, crop_lower) defining the
region to crop from the original image.
Raises:
- ValueError: If the keypoint is outside the image boundaries or inputs
- are invalid types/formats.
+ ValueError: If inputs are invalid types/formats, keypoint is outside
+ the image boundaries, or aspect_ratio_mode is invalid.
"""
# --- Input Validation ---
if (
@@ -54,8 +57,13 @@ def calculate_crop_coordinates(
"Error:image_size must be a tuple of two positive integers (width, height)"
)
- if not isinstance(target_size, int) or target_size <= 0:
- raise ValueError("Error: target_size must be a positive integer")
+ if not isinstance(base_target_size, int) or base_target_size <= 0:
+ raise ValueError("Error: base_target_size must be a positive integer")
+
+ if aspect_ratio_mode not in ["square", "horizontal", "vertical"]:
+ raise ValueError(
+ "Error: aspect_ratio_mode must be 'square', 'horizontal', or 'vertical'"
+ )
x, y = keypoint
width, height = image_size
@@ -66,22 +74,40 @@ def calculate_crop_coordinates(
f"(width={width}, height={height})"
)
+ # --- Calculate Target Crop Dimensions based on Aspect Ratio ---
+ sqrt2 = math.sqrt(2)
+ if aspect_ratio_mode == "square":
+ target_crop_width = base_target_size
+ target_crop_height = base_target_size
+ elif aspect_ratio_mode == "horizontal": # 2:1
+ target_crop_width = round(base_target_size * sqrt2)
+ target_crop_height = round(base_target_size / sqrt2)
+ else: # aspect_ratio_mode == "vertical" (1:2)
+ target_crop_width = round(base_target_size / sqrt2)
+ target_crop_height = round(base_target_size * sqrt2)
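+    # e.g. base_target_size=300 gives ~300x300 (square), ~424x212 (horizontal),
+    # ~212x424 (vertical); each covers roughly the same crop area (~90k pixels)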
+
+ # Ensure dimensions are at least 1
+ target_crop_width = max(1, target_crop_width)
+ target_crop_height = max(1, target_crop_height)
+
# --- Calculate Crop Box ---
- half_size = target_size // 2
+ # Calculate half-sizes (integer division)
+ half_width = target_crop_width // 2
+ half_height = target_crop_height // 2
# Calculate the ideal top-left corner to center the box
- ideal_x_min = x - half_size
- ideal_y_min = y - half_size
+ ideal_x_min = x - half_width
+ ideal_y_min = y - half_height
# Initialize the actual top-left corner
crop_left = ideal_x_min
crop_upper = ideal_y_min
# Adjust if the box extends beyond the right or bottom boundaries
- if crop_left + target_size > width:
- crop_left = width - target_size
- if crop_upper + target_size > height:
- crop_upper = height - target_size
+ if crop_left + target_crop_width > width:
+ crop_left = width - target_crop_width
+ if crop_upper + target_crop_height > height:
+ crop_upper = height - target_crop_height
# Adjust if the box extends beyond the left or top boundaries
if crop_left < 0:
@@ -91,13 +117,11 @@ def calculate_crop_coordinates(
# Calculate the final right and lower bounds for PIL crop (exclusive)
# Clamp the right/lower bounds to the image dimensions, important if
- # target_size > image width/height.
- crop_right = min(crop_left + target_size, width)
- crop_lower = min(crop_upper + target_size, height)
+ # target crop dimensions > image width/height.
+ crop_right = min(crop_left + target_crop_width, width)
+ crop_lower = min(crop_upper + target_crop_height, height)
# Ensure left/upper are also clamped in case target_size > image dimensions
- # which might lead to negative values after boundary adjustments.
- # The previous checks already handle this, but being explicit doesn't hurt.
crop_left = max(0, crop_left)
crop_upper = max(0, crop_upper)
@@ -107,29 +131,48 @@ def calculate_crop_coordinates(
def zoom(
image_name: str,
keypoint: list[int],
+ aspect_ratio_mode: str,
**kwargs,
) -> Image.Image:
"""
- Returns an image zoomed in on the specified keypoint.
+ Returns an image zoomed in on the specified keypoint with a given aspect ratio.
This is useful to see a region of interest with higher clarity, like for reading text or identifying objects.
+    The choice of aspect ratio depends on the task.
+ If you are looking at horizontal text or some other wide region, you might use "horizontal" aspect ratio.
+ If you are looking at vertical text or some other tall region, you might use "vertical" aspect ratio.
+ If you are looking at a region more generally, you might use "square" aspect ratio.
+
Args:
image_name: str, the name of the image to zoom in on. Can only be called on the "input_image".
keypoint: list[int], the [x, y] coordinates of the point to center the zoom on. Must be within the image boundaries.
+ aspect_ratio_mode: str, the desired aspect ratio for the zoom window. Must be one of "square", "horizontal", "vertical".
+ "square" is 1:1, "horizontal" is 2:1 (wider), "vertical" is 1:2 (taller).
Returns:
- The image zoomed in on the specified keypoint.
+ The image zoomed in on the specified keypoint with the requested aspect ratio.
Examples:
name: zoom
image_name: input_image
+ aspect_ratio_mode: square
keypoint: [500, 400]
+
+
+
+ name: zoom
+ image_name: input_image
+ aspect_ratio_mode: horizontal
+ keypoint: [23, 44]
+
name: zoom
image_name: input_image
- keypoint: [50, 40]
+ aspect_ratio_mode: vertical
+ keypoint: [723, 461]
+
"""
# get and validate the image
@@ -146,27 +189,50 @@ def zoom(
f"Error: Image {image_name} is not the input_image. This tool can only be called on the input_image."
)
- width, height = image.size
- # size we will crop to
- target_size = 300
-
- # we will resize the image to the larger dimension of the input image, or 400 if it's a smaller image
- resize_size = max(width, height, 400)
+ original_width, original_height = image.size
+ # base size for cropping area calculation
+ base_crop_target_size = 300
- # Validate keypoint and calculate crop box using the helper function
- # ValueError will be raised by the helper if keypoint is invalid
- crop_box = calculate_crop_coordinates(keypoint, (width, height), target_size)
+ # Validate aspect_ratio_mode and keypoint, then calculate crop box
+ crop_box = calculate_crop_coordinates(
+ keypoint,
+ (original_width, original_height),
+ aspect_ratio_mode,
+ base_target_size=base_crop_target_size,
+ )
# crop the image to the calculated box
cropped_image = image.crop(crop_box)
-
- # Resize the cropped image to the target size
- output_image = cropped_image.resize(
- (resize_size, resize_size), Image.Resampling.LANCZOS
+ crop_w, crop_h = cropped_image.size
+
+ # If crop resulted in zero size (e.g., invalid crop box calculation edge case)
+ if crop_w == 0 or crop_h == 0:
+ raise ValueError("Error: Cropped image has zero dimension")
+
+ # Determine the base dimension for resizing the output - use original image dims or 400 min
+ base_resize_dim = max(original_width, original_height, 400)
+
+ # Calculate the final output size, preserving aspect ratio of the cropped image
+ # Scale so the LARGER dimension of the output matches base_resize_dim
+ if crop_w >= crop_h: # Wider or square crop
+ output_w = base_resize_dim
+ output_h = round(output_w * crop_h / crop_w) # Maintain aspect ratio
+ else: # Taller crop
+ output_h = base_resize_dim
+ output_w = round(output_h * crop_w / crop_h) # Maintain aspect ratio
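+    # e.g. a 424x212 horizontal crop from a 1000x800 image is resized to 1000x500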
+
+ # Ensure minimum dimension is 1
+ output_w = max(1, output_w)
+ output_h = max(1, output_h)
+ output_size = (output_w, output_h)
+
+ # Resize the cropped image to the calculated output size, preserving aspect ratio
+ output_image = cropped_image.resize(output_size, Image.Resampling.LANCZOS)
+
+ print(
+ f"Zoom tool cropped to {crop_w}x{crop_h}, resized output to: {output_image.size}"
)
- print(f"Zoom tool output image size: {output_image.size}")
-
return output_image
@@ -174,7 +240,7 @@ def parse_zoom_args(raw_args: RawToolArgs) -> TypedToolArgs:
"""
Parses raw string arguments for the zoom tool (keypoint version).
- Expects keys: 'name', 'image_name', 'keypoint'.
+ Expects keys: 'name', 'image_name', 'keypoint', 'aspect_ratio_mode'.
Converts 'keypoint' from a JSON string to a list of integers.
Detailed validation of values (e.g., keypoint coordinates)
is deferred to the zoom function itself.
@@ -184,14 +250,20 @@ def parse_zoom_args(raw_args: RawToolArgs) -> TypedToolArgs:
Returns:
A dictionary containing the arguments with basic type conversions applied,
- ready for the zoom function. Keys: 'image_name', 'keypoint'.
+ ready for the zoom function. Keys: 'image_name', 'keypoint', 'aspect_ratio_mode'.
Raises:
ValueError: If required keys are missing, extra keys are present,
- or basic type conversion fails (e.g., 'keypoint' is not valid JSON
- or doesn't result in a list of two integers).
+            basic type conversion fails (e.g., 'keypoint' is not valid JSON
+            or doesn't result in a list of two integers). Validation of the
+            'aspect_ratio_mode' value is deferred to the zoom function.
"""
- required_keys = {"name", "image_name", "keypoint"}
+ required_keys = {
+ "name",
+ "image_name",
+ "keypoint",
+ "aspect_ratio_mode",
+ } # Added aspect_ratio_mode
actual_keys = set(raw_args.keys())
# 1. Check for Missing Keys
@@ -224,9 +296,12 @@ def parse_zoom_args(raw_args: RawToolArgs) -> TypedToolArgs:
raise ValueError(
"Error: Both elements in 'keypoint' list must be integers."
)
-
typed_args["keypoint"] = keypoint_list
+        # Store aspect_ratio_mode; value validation is deferred to the zoom function
+ aspect_mode = raw_args["aspect_ratio_mode"]
+ typed_args["aspect_ratio_mode"] = aspect_mode
+
except json.JSONDecodeError:
raise ValueError(
f"Error: Invalid JSON format for 'keypoint': '{raw_args['keypoint']}'"
@@ -254,7 +329,7 @@ def test_calc_centered(sample_image):
img_size = sample_image["input_image"].size
kp = [500, 400]
ts = 300
- box = calculate_crop_coordinates(kp, img_size, ts)
+ box = calculate_crop_coordinates(kp, img_size, "square", ts)
assert box == (350, 250, 650, 550)
@@ -262,7 +337,7 @@ def test_calc_top_left(sample_image):
img_size = sample_image["input_image"].size
kp = [50, 40]
ts = 300
- box = calculate_crop_coordinates(kp, img_size, ts)
+ box = calculate_crop_coordinates(kp, img_size, "square", ts)
assert box == (0, 0, 300, 300)
@@ -270,7 +345,7 @@ def test_calc_bottom_right(sample_image):
img_size = sample_image["input_image"].size
kp = [950, 750]
ts = 300
- box = calculate_crop_coordinates(kp, img_size, ts)
+ box = calculate_crop_coordinates(kp, img_size, "square", ts)
assert box == (700, 500, 1000, 800)
@@ -278,7 +353,7 @@ def test_calc_right_middle(sample_image):
img_size = sample_image["input_image"].size
kp = [950, 400]
ts = 300
- box = calculate_crop_coordinates(kp, img_size, ts)
+ box = calculate_crop_coordinates(kp, img_size, "square", ts)
assert box == (700, 250, 1000, 550)
@@ -286,7 +361,7 @@ def test_calc_small_image(sample_image):
small_img_size = (200, 150)
kp = [100, 75]
ts = 300
- box = calculate_crop_coordinates(kp, small_img_size, ts)
+ box = calculate_crop_coordinates(kp, small_img_size, "square", ts)
assert box == (0, 0, 200, 150)
@@ -294,30 +369,30 @@ def test_calc_invalid_keypoint_coords(sample_image):
img_size = sample_image["input_image"].size
ts = 300
with pytest.raises(ValueError, match="outside the image boundaries"):
- calculate_crop_coordinates([1000, 400], img_size, ts)
+ calculate_crop_coordinates([1000, 400], img_size, "square", ts)
with pytest.raises(ValueError, match="outside the image boundaries"):
- calculate_crop_coordinates([-1, 400], img_size, ts)
+ calculate_crop_coordinates([-1, 400], img_size, "square", ts)
with pytest.raises(ValueError, match="outside the image boundaries"):
- calculate_crop_coordinates([500, 800], img_size, ts)
+ calculate_crop_coordinates([500, 800], img_size, "square", ts)
with pytest.raises(ValueError, match="outside the image boundaries"):
- calculate_crop_coordinates([500, -1], img_size, ts)
+ calculate_crop_coordinates([500, -1], img_size, "square", ts)
def test_calc_invalid_input_types():
with pytest.raises(ValueError, match="keypoint must be a list or tuple"):
- calculate_crop_coordinates("[100, 100]", (500, 500), 300)
+ calculate_crop_coordinates("[100, 100]", (500, 500), "square", 300)
with pytest.raises(ValueError, match="keypoint must be a list or tuple"):
- calculate_crop_coordinates([100], (500, 500), 300)
+ calculate_crop_coordinates([100], (500, 500), "square", 300)
with pytest.raises(ValueError, match="keypoint must be a list or tuple"):
- calculate_crop_coordinates([100.0, 100], (500, 500), 300)
+ calculate_crop_coordinates([100.0, 100], (500, 500), "square", 300)
with pytest.raises(ValueError, match="image_size must be a tuple"):
- calculate_crop_coordinates([100, 100], [500, 500], 300)
+ calculate_crop_coordinates([100, 100], [500, 500], "square", 300)
with pytest.raises(ValueError, match="image_size must be a tuple"):
- calculate_crop_coordinates([100, 100], (500, 0), 300)
+ calculate_crop_coordinates([100, 100], (500, 0), "square", 300)
with pytest.raises(ValueError, match="target_size must be a positive integer"):
- calculate_crop_coordinates([100, 100], (500, 500), 0)
+ calculate_crop_coordinates([100, 100], (500, 500), "square", 0)
with pytest.raises(ValueError, match="target_size must be a positive integer"):
- calculate_crop_coordinates([100, 100], (500, 500), -100)
+ calculate_crop_coordinates([100, 100], (500, 500), "square", -100)
# --- Tests for zoom function ---
@@ -325,7 +400,7 @@ def test_calc_invalid_input_types():
def test_zoom_invalid_image_name(sample_image):
with pytest.raises(ValueError, match="not found"):
- zoom("nonexistent_image", [100, 100], images=sample_image)
+ zoom("nonexistent_image", [100, 100], "square", images=sample_image)
# Add test for incorrect image name usage (e.g., "test_image" instead of "input_image")
@@ -334,6 +409,7 @@ def test_zoom_incorrect_image_name_usage(sample_image):
zoom(
"test_image",
[100, 100],
+ "square",
images={"test_image": sample_image["input_image"]},
)
@@ -342,44 +418,50 @@ def test_zoom_incorrect_image_name_usage(sample_image):
def test_zoom_invalid_keypoint_format(sample_image):
with pytest.raises(ValueError, match="keypoint must be a list or tuple"):
zoom(
- "input_image", "[100, 100]", images=sample_image
+ "input_image", "[100, 100]", "square", images=sample_image
) # Pass string instead of list
with pytest.raises(ValueError, match="keypoint must be a list or tuple"):
- zoom("input_image", [100], images=sample_image) # Pass list with 1 element
+ zoom(
+ "input_image", [100], "square", images=sample_image
+ ) # Pass list with 1 element
with pytest.raises(ValueError, match="keypoint must be a list or tuple"):
- zoom("input_image", [100.0, 100], images=sample_image) # Pass float
+ zoom("input_image", [100.0, 100], "square", images=sample_image) # Pass float
# Test keypoint outside image bounds (should be caught by calculate_crop_coordinates)
def test_zoom_keypoint_out_of_bounds(sample_image):
img_size = sample_image["input_image"].size
with pytest.raises(ValueError, match="outside the image boundaries"):
- zoom("input_image", [img_size[0], 100], images=sample_image) # x = width
+ zoom(
+ "input_image", [img_size[0], 100], "square", images=sample_image
+ ) # x = width
with pytest.raises(ValueError, match="outside the image boundaries"):
- zoom("input_image", [-1, 100], images=sample_image) # x < 0
+ zoom("input_image", [-1, 100], "square", images=sample_image) # x < 0
with pytest.raises(ValueError, match="outside the image boundaries"):
- zoom("input_image", [100, img_size[1]], images=sample_image) # y = height
+ zoom(
+ "input_image", [100, img_size[1]], "square", images=sample_image
+ ) # y = height
with pytest.raises(ValueError, match="outside the image boundaries"):
- zoom("input_image", [100, -1], images=sample_image) # y < 0
+ zoom("input_image", [100, -1], "square", images=sample_image) # y < 0
def test_zoom_basic_centered(sample_image):
keypoint = [500, 400] # Center of 1000x800 image
- result = zoom("input_image", keypoint, images=sample_image)
+ result = zoom("input_image", keypoint, "square", images=sample_image)
assert isinstance(result, Image.Image)
assert result.size == (300, 300) # Always 300x300 output
def test_zoom_edge_case_top_left(sample_image):
keypoint = [50, 40] # Near top-left
- result = zoom("input_image", keypoint, images=sample_image)
+ result = zoom("input_image", keypoint, "square", images=sample_image)
assert isinstance(result, Image.Image)
assert result.size == (300, 300)
def test_zoom_edge_case_bottom_right(sample_image):
keypoint = [950, 750] # Near bottom-right
- result = zoom("input_image", keypoint, images=sample_image)
+ result = zoom("input_image", keypoint, "square", images=sample_image)
assert isinstance(result, Image.Image)
assert result.size == (300, 300)
@@ -389,7 +471,7 @@ def test_zoom_small_image(sample_image):
small_img = Image.new("RGB", (200, 150))
small_sample = {"input_image": small_img}
keypoint = [100, 75] # Center of small image
- result = zoom("input_image", keypoint, images=small_sample)
+ result = zoom("input_image", keypoint, "square", images=small_sample)
assert isinstance(result, Image.Image)
assert result.size == (300, 300) # Output is still 300x300 after resize
@@ -404,7 +486,7 @@ def test_zoom_content_check(sample_image):
filled_sample = {"input_image": img}
keypoint = [500, 400] # Center point within the red box
- result = zoom("input_image", keypoint, images=filled_sample)
+ result = zoom("input_image", keypoint, "square", images=filled_sample)
assert result.size == (300, 300)
# Check a central pixel in the output (should correspond to the red box)
# The original keypoint [500, 400] should map to the center [150, 150] in the 300x300 output
@@ -416,14 +498,25 @@ def test_zoom_content_check(sample_image):
def test_parse_valid_args():
- raw = {"name": "zoom", "image_name": "input_image", "keypoint": "[150, 250]"}
- expected = {"image_name": "input_image", "keypoint": [150, 250]}
+ raw = {
+ "name": "zoom",
+ "image_name": "input_image",
+ "keypoint": "[150, 250]",
+ "aspect_ratio_mode": "square",
+ }
+ expected = {
+ "image_name": "input_image",
+ "keypoint": [150, 250],
+ "aspect_ratio_mode": "square",
+ }
assert parse_zoom_args(raw) == expected
def test_parse_missing_key():
- raw = {"name": "zoom", "image_name": "input_image"}
- with pytest.raises(ValueError, match="Missing required arguments.*keypoint"):
+ raw = {"name": "zoom", "image_name": "input_image", "keypoint": "[150, 250]"}
+ with pytest.raises(
+ ValueError, match="Missing required arguments.*aspect_ratio_mode"
+ ):
parse_zoom_args(raw)
@@ -432,6 +525,7 @@ def test_parse_extra_key():
"name": "zoom",
"image_name": "input_image",
"keypoint": "[100, 100]",
+ "aspect_ratio_mode": "square",
"extra": "bad",
}
with pytest.raises(ValueError, match="Unexpected arguments.*extra"):
@@ -443,6 +537,7 @@ def test_parse_invalid_json():
"name": "zoom",
"image_name": "input_image",
"keypoint": "[100, 100",
+ "aspect_ratio_mode": "square",
} # Missing closing bracket
with pytest.raises(ValueError, match="Invalid JSON format for 'keypoint'"):
parse_zoom_args(raw)
@@ -453,6 +548,7 @@ def test_parse_keypoint_not_list():
"name": "zoom",
"image_name": "input_image",
"keypoint": '{"x": 100, "y": 100}',
+ "aspect_ratio_mode": "square",
}
with pytest.raises(ValueError, match="'keypoint' must be a JSON list"):
parse_zoom_args(raw)
@@ -503,18 +599,18 @@ def test_parse_keypoint_wrong_type():
try:
print(f"\nTesting center keypoint: {kp_center}")
- result_center = zoom("input_image", kp_center, images=images_dict)
+ result_center = zoom("input_image", kp_center, "square", images=images_dict)
print(f"Center zoom output size: {result_center.size}")
# result_center.show() # Optionally display the image
print(f"\nTesting edge keypoint: {kp_edge}")
- result_edge = zoom("input_image", kp_edge, images=images_dict)
+ result_edge = zoom("input_image", kp_edge, "square", images=images_dict)
print(f"Edge zoom output size: {result_edge.size}")
# result_edge.show() # Optionally display the image
print("\nTesting invalid keypoint:")
try:
- zoom("input_image", [1200, 100], images=images_dict)
+ zoom("input_image", [1200, 100], "square", images=images_dict)
except ValueError as e:
print(f"Caught expected error: {e}")
@@ -523,6 +619,7 @@ def test_parse_keypoint_wrong_type():
"name": "zoom",
"image_name": "input_image",
"keypoint": "[150, 250]",
+ "aspect_ratio_mode": "square",
}
parsed = parse_zoom_args(raw_valid)
print(f"Parsed valid args: {parsed}")
@@ -531,6 +628,7 @@ def test_parse_keypoint_wrong_type():
"name": "zoom",
"image_name": "input_image",
"keypoint": "[150.5, 250]",
+ "aspect_ratio_mode": "square",
}
try:
parse_zoom_args(raw_invalid)