From 5c763274d45d3c23b42397d4f7a5c8689c5cdb6a Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Sun, 4 May 2025 22:28:12 +0000 Subject: [PATCH 01/12] script for collecting demo rois --- src/r1_vlm/datasets/text_vqa/collect_demos.py | 279 ++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 src/r1_vlm/datasets/text_vqa/collect_demos.py diff --git a/src/r1_vlm/datasets/text_vqa/collect_demos.py b/src/r1_vlm/datasets/text_vqa/collect_demos.py new file mode 100644 index 0000000..5b79efa --- /dev/null +++ b/src/r1_vlm/datasets/text_vqa/collect_demos.py @@ -0,0 +1,279 @@ +import json +import os + +import matplotlib.pyplot as plt +import matplotlib.widgets as widgets +from datasets import load_dataset +from PIL import Image + +OUTPUT_FILE = "zoom_demos.jsonl" + + +# --- Image Resizing Logic (copied from text_vqa_tool_use_r1.py) --- +def resize_image(image): + """ + Resizes the image if its longer side exceeds 1024 pixels, + maintaining aspect ratio. + """ + # if the longer side is greater than 1024, resize it so the longer side is 1024 + if image.mode == "L": # Handle grayscale images + image = image.convert("RGB") + + if image.size[0] > 1024 or image.size[1] > 1024: + longer_side = max(image.size[0], image.size[1]) + image = image.resize( + ( + int(1024 * image.size[0] / longer_side), + int(1024 * image.size[1] / longer_side), + ), + Image.Resampling.LANCZOS, # Use LANCZOS for better quality + ) + return image + + +# --- Global State (simpler for this script) --- +current_example = None +original_image = None +original_size = None +displayed_image = None +displayed_size = None +selected_keypoint_original = None +selected_keypoint_display = None +marker_plot = None +dataset_iterator = None +demos_saved_count = 0 + +# --- Matplotlib UI Elements --- +fig, ax = plt.subplots(figsize=(10, 8)) # Make figure a bit larger +plt.subplots_adjust(bottom=0.25) # More space for controls and title +ax.axis("off") + +ax_save = plt.axes([0.81, 0.05, 0.15, 0.075]) +btn_save = widgets.Button(ax_save, "Save & Next") +btn_save.active = False # Initially disabled + +ax_undo = plt.axes([0.65, 0.05, 0.15, 0.075]) +btn_undo = widgets.Button(ax_undo, "Undo Click") +btn_undo.active = False # Initially disabled + +ax_skip = plt.axes([0.49, 0.05, 0.15, 0.075]) +btn_skip = widgets.Button(ax_skip, "Skip") + +# Text to show saved count +ax_count_text = plt.axes([0.05, 0.05, 0.4, 0.075]) +ax_count_text.axis("off") +count_text_obj = ax_count_text.text( + 0.5, + 0.5, + f"Saved: {demos_saved_count}", + ha="center", + va="center", + transform=ax_count_text.transAxes, +) + + +# --- Core Logic Functions --- +def update_saved_count_display(): + """Updates the saved count text on the plot.""" + global demos_saved_count + count_text_obj.set_text(f"Saved: {demos_saved_count}") + plt.draw() + + +def remove_marker(): + """Removes the keypoint marker from the plot if it exists.""" + global marker_plot + if marker_plot: + try: + marker_plot.remove() + except ValueError: # May have already been removed if plot cleared + pass + marker_plot = None + + +def display_current_example(): + """Displays the current image and question.""" + global displayed_image, displayed_size, marker_plot + if current_example is None: + return + + remove_marker() # Clear marker from previous example + + # Store original and prepare display image + original_image_pil = current_example["image"] + if original_image_pil.mode == "L": # Ensure RGB + original_image_pil = original_image_pil.convert("RGB") + + displayed_image = 
resize_image(original_image_pil.copy()) # Use resized for display + displayed_size = displayed_image.size + + ax.cla() # Clear axes before drawing new image + ax.imshow(displayed_image) + question_text = ( + f"Q ({current_example['question_id']}): {current_example['question']}" + ) + ax.set_title(question_text, wrap=True, fontsize=10) + ax.axis("off") + plt.draw() + + +def load_next_example(): + """Loads the next example from the dataset iterator and updates the UI.""" + global current_example, original_image, original_size + global selected_keypoint_original, selected_keypoint_display, marker_plot + global dataset_iterator + + if dataset_iterator is None: + print("Error: Dataset iterator not initialized.") + plt.close(fig) + return + + try: + current_example = next(dataset_iterator) + original_image = current_example["image"] + if original_image.mode == "L": # Ensure RGB for original too + original_image = original_image.convert("RGB") + original_size = original_image.size + + # Reset state for the new example + selected_keypoint_original = None + selected_keypoint_display = None + remove_marker() # Explicitly remove marker here too + + display_current_example() + + btn_save.active = False + btn_undo.active = False + plt.draw() # Update button states visually + + except StopIteration: + print("Finished processing all examples in the dataset split.") + plt.close(fig) + except Exception as e: + print(f"An unexpected error occurred loading the next example: {e}") + plt.close(fig) + + +def on_click(event): + """Handles mouse clicks on the image axes to select a keypoint.""" + global selected_keypoint_original, selected_keypoint_display, marker_plot + global original_size, displayed_size + + if event.inaxes != ax or original_size is None or displayed_size is None: + return # Click outside image axes or data not loaded + + # Get click coordinates relative to displayed image + x_disp, y_disp = int(event.xdata), int(event.ydata) + + # Map coordinates back to original image dimensions + scale_w = original_size[0] / displayed_size[0] + scale_h = original_size[1] / displayed_size[1] + x_orig = int(round(x_disp * scale_w)) + y_orig = int(round(y_disp * scale_h)) + + # Clamp coordinates to ensure they are within the original image bounds + x_orig = max(0, min(original_size[0] - 1, x_orig)) + y_orig = max(0, min(original_size[1] - 1, y_orig)) + + selected_keypoint_original = [x_orig, y_orig] + selected_keypoint_display = [x_disp, y_disp] + + # Update marker + remove_marker() + (marker_plot,) = ax.plot(x_disp, y_disp, "r+", markersize=12, markeredgewidth=2) + + # Enable Save and Undo buttons + btn_save.active = True + btn_undo.active = True + plt.draw() + + +def save_callback(event): + """Saves the current example's info and keypoint, then loads the next.""" + global demos_saved_count + if current_example is None or selected_keypoint_original is None: + print("Warning: Cannot save - no example loaded or keypoint selected.") + return + + data = { + "image_id": current_example["image_id"], + "question_id": current_example["question_id"], + "keypoint": selected_keypoint_original, + } + + try: + with open(OUTPUT_FILE, "a") as f: + f.write(json.dumps(data) + "\n") + print(f"Saved: QID {data['question_id']} -> {data['keypoint']}") + demos_saved_count += 1 + update_saved_count_display() + load_next_example() + except IOError as e: + print(f"Error saving data to {OUTPUT_FILE}: {e}") + except Exception as e: + print(f"An unexpected error occurred during save: {e}") + + +def undo_callback(event): + """Clears the 
selected keypoint and removes the marker.""" + global selected_keypoint_original, selected_keypoint_display + remove_marker() + selected_keypoint_original = None + selected_keypoint_display = None + btn_save.active = False + btn_undo.active = False + plt.draw() + + +def skip_callback(event): + """Loads the next example without saving.""" + print("Skipped.") + load_next_example() + + +# --- Initialization and Execution --- +def main(): + global dataset_iterator, demos_saved_count + + print("Loading TextVQA 'train' split...") + try: + dataset = load_dataset("lmms-lab/textvqa", split="train") + dataset_iterator = iter(dataset) + print("Dataset loaded.") + except Exception as e: + print(f"Failed to load dataset: {e}") + return + + # Count existing demos if file exists + if os.path.exists(OUTPUT_FILE): + try: + with open(OUTPUT_FILE, "r") as f: + demos_saved_count = sum(1 for line in f) + print( + f"Found {demos_saved_count} existing demonstrations in {OUTPUT_FILE}." + ) + except Exception as e: + print( + f"Warning: Could not read existing demo count from {OUTPUT_FILE}: {e}" + ) + + # Connect callbacks + btn_save.on_clicked(save_callback) + btn_undo.on_clicked(undo_callback) + btn_skip.on_clicked(skip_callback) + fig.canvas.mpl_connect("button_press_event", on_click) + + # Load the first example and show the plot + print("Loading first example...") + update_saved_count_display() # Show initial count + load_next_example() + + if current_example: # Only show plot if first load succeeded + print("Starting labeling interface. Close the plot window to exit.") + plt.show() + else: + print("Could not load the first example. Exiting.") + + +if __name__ == "__main__": + main() From cb01600dc48308ccdebcfb9bb508e447f9f3b0bf Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Sun, 4 May 2025 22:36:20 +0000 Subject: [PATCH 02/12] improve gui --- src/r1_vlm/datasets/text_vqa/collect_demos.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/r1_vlm/datasets/text_vqa/collect_demos.py b/src/r1_vlm/datasets/text_vqa/collect_demos.py index 5b79efa..3097d42 100644 --- a/src/r1_vlm/datasets/text_vqa/collect_demos.py +++ b/src/r1_vlm/datasets/text_vqa/collect_demos.py @@ -109,10 +109,15 @@ def display_current_example(): ax.cla() # Clear axes before drawing new image ax.imshow(displayed_image) + answers = current_example.get("answers", []) # Get answers, default to empty list + answers_str = ", ".join(answers) if answers else "N/A" question_text = ( - f"Q ({current_example['question_id']}): {current_example['question']}" + f"Q ({current_example['question_id']}): {current_example['question']}\n" + f"Answers: {answers_str}" ) - ax.set_title(question_text, wrap=True, fontsize=10) + ax.set_title( + question_text, wrap=True, fontsize=9 + ) # Slightly smaller font for more text ax.axis("off") plt.draw() From 9b3a1cba5eec48071c31cb5fbb9b980b6413b68b Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Sun, 4 May 2025 22:39:31 +0000 Subject: [PATCH 03/12] shuffle --- src/r1_vlm/datasets/text_vqa/collect_demos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/r1_vlm/datasets/text_vqa/collect_demos.py b/src/r1_vlm/datasets/text_vqa/collect_demos.py index 3097d42..23cad67 100644 --- a/src/r1_vlm/datasets/text_vqa/collect_demos.py +++ b/src/r1_vlm/datasets/text_vqa/collect_demos.py @@ -242,7 +242,7 @@ def main(): print("Loading TextVQA 'train' split...") try: - dataset = load_dataset("lmms-lab/textvqa", split="train") + dataset = load_dataset("lmms-lab/textvqa", 
split="train").shuffle(seed=42) dataset_iterator = iter(dataset) print("Dataset loaded.") except Exception as e: From a64b7e461de8ac1867283d2867ac474c11e19ff8 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Sun, 4 May 2025 15:52:13 -0700 Subject: [PATCH 04/12] collected some demos --- src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl diff --git a/src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl b/src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl new file mode 100644 index 0000000..cb2d4d0 --- /dev/null +++ b/src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl @@ -0,0 +1,101 @@ +{"image_id": "0054c91397f2fe05", "question_id": 0, "keypoint": [352, 312]} +{"image_id": "005635e119b9f32f", "question_id": 1, "keypoint": [561, 322]} +{"image_id": "005635e119b9f32f", "question_id": 2, "keypoint": [80, 326]} +{"image_id": "00685bc495504d61", "question_id": 3, "keypoint": [119, 1010]} +{"image_id": "00685bc495504d61", "question_id": 4, "keypoint": [159, 567]} +{"image_id": "006d10667d17b924", "question_id": 5, "keypoint": [445, 461]} +{"image_id": "006d10667d17b924", "question_id": 6, "keypoint": [442, 140]} +{"image_id": "00c359f294f7dcd9", "question_id": 7, "keypoint": [373, 399]} +{"image_id": "00c359f294f7dcd9", "question_id": 8, "keypoint": [382, 400]} +{"image_id": "0122c5279a501df2", "question_id": 9, "keypoint": [907, 247]} +{"image_id": "0122c5279a501df2", "question_id": 10, "keypoint": [300, 275]} +{"image_id": "015c537f722d115e", "question_id": 11, "keypoint": [646, 362]} +{"image_id": "0202faf23a9aae11", "question_id": 12, "keypoint": [169, 601]} +{"image_id": "02b2daa20b00f5d3", "question_id": 13, "keypoint": [162, 195]} +{"image_id": "02b2daa20b00f5d3", "question_id": 14, "keypoint": [540, 322]} +{"image_id": "02da7a42e971e3cb", "question_id": 15, "keypoint": [520, 457]} +{"image_id": "02da7a42e971e3cb", "question_id": 16, "keypoint": [518, 460]} +{"image_id": "d709940b064b2ea7", "question_id": 4782, "keypoint": [525, 487]} +{"image_id": "d7e15efeee988b35", "question_id": 23953, "keypoint": [360, 631]} +{"image_id": "258dcb6be2766e6b", "question_id": 10485, "keypoint": [54, 296]} +{"image_id": "ae0e286ea6301919", "question_id": 6787, "keypoint": [854, 483]} +{"image_id": "a9085b21fafeb7e0", "question_id": 12304, "keypoint": [357, 519]} +{"image_id": "0876802af9ca4f53", "question_id": 2069, "keypoint": [295, 537]} +{"image_id": "51153c1aeeba2142", "question_id": 21859, "keypoint": [568, 533]} +{"image_id": "bb4f9f6fa8be3784", "question_id": 23499, "keypoint": [573, 303]} +{"image_id": "0cc86f2bc0023406", "question_id": 18672, "keypoint": [290, 534]} +{"image_id": "e47a56d412750c3e", "question_id": 14447, "keypoint": [383, 360]} +{"image_id": "3109e0ab52305483", "question_id": 5725, "keypoint": [336, 263]} +{"image_id": "65277ae6a3eb6c8a", "question_id": 14311, "keypoint": [478, 644]} +{"image_id": "3678e10d1ee15709", "question_id": 2707, "keypoint": [800, 190]} +{"image_id": "04f9cc3db7d8da24", "question_id": 1303, "keypoint": [646, 456]} +{"image_id": "99c2dd0295e97a04", "question_id": 30730, "keypoint": [92, 417]} +{"image_id": "fcab06f91d7da18d", "question_id": 33237, "keypoint": [414, 591]} +{"image_id": "61a5eb30c602a678", "question_id": 3308, "keypoint": [461, 508]} +{"image_id": "3bc293efa0203314", "question_id": 18512, "keypoint": [370, 471]} +{"image_id": "4a8830e36c78f5e7", "question_id": 14061, "keypoint": [553, 267]} +{"image_id": "beed9557550278a3", "question_id": 
9481, "keypoint": [545, 327]} +{"image_id": "4c858f86738a72fd", "question_id": 1534, "keypoint": [666, 348]} +{"image_id": "104bf2461f9962a8", "question_id": 5445, "keypoint": [452, 532]} +{"image_id": "5b41dbe96cd05104", "question_id": 26950, "keypoint": [689, 363]} +{"image_id": "8dad2e89e4ad9839", "question_id": 6536, "keypoint": [58, 189]} +{"image_id": "ad899b711c4e5613", "question_id": 6784, "keypoint": [119, 406]} +{"image_id": "0927271c7f9f3b9e", "question_id": 30035, "keypoint": [153, 324]} +{"image_id": "cbe00b9c28494af5", "question_id": 17311, "keypoint": [444, 538]} +{"image_id": "2d4f3317cfe522c8", "question_id": 21210, "keypoint": [438, 598]} +{"image_id": "8568bc5fc7abbb33", "question_id": 16729, "keypoint": [611, 114]} +{"image_id": "a70b30b1806e1226", "question_id": 23157, "keypoint": [732, 293]} +{"image_id": "008b8556a73601ae", "question_id": 17883, "keypoint": [384, 830]} +{"image_id": "1b8bf6437a221a10", "question_id": 1412, "keypoint": [461, 222]} +{"image_id": "4f8d13bce21b6ada", "question_id": 34364, "keypoint": [397, 359]} +{"image_id": "1a47acd8e0d07441", "question_id": 28380, "keypoint": [540, 258]} +{"image_id": "143d9adc0368b1dd", "question_id": 17992, "keypoint": [508, 249]} +{"image_id": "bba917b91c4c46dd", "question_id": 12361, "keypoint": [961, 402]} +{"image_id": "0204574848b0bf05", "question_id": 29549, "keypoint": [363, 521]} +{"image_id": "083ea2d88ee71e24", "question_id": 31564, "keypoint": [211, 763]} +{"image_id": "2f6621dfaa4e4672", "question_id": 27459, "keypoint": [745, 452]} +{"image_id": "d929fd569dbae5fe", "question_id": 4809, "keypoint": [516, 882]} +{"image_id": "20bb636326491022", "question_id": 20971, "keypoint": [503, 63]} +{"image_id": "8df181c0899df120", "question_id": 3843, "keypoint": [446, 204]} +{"image_id": "923545f7c0785426", "question_id": 34026, "keypoint": [505, 864]} +{"image_id": "89f64c747cf016ec", "question_id": 29772, "keypoint": [531, 543]} +{"image_id": "ba0c5d07cd0803fa", "question_id": 25480, "keypoint": [495, 92]} +{"image_id": "6ce0bbbee320e928", "question_id": 12134, "keypoint": [134, 686]} +{"image_id": "a760819a4cd02bc5", "question_id": 32792, "keypoint": [479, 661]} +{"image_id": "04b99754c51e6092", "question_id": 1296, "keypoint": [219, 377]} +{"image_id": "061d8e57615d98c4", "question_id": 12598, "keypoint": [140, 536]} +{"image_id": "f3d9c7c15556813a", "question_id": 24298, "keypoint": [890, 89]} +{"image_id": "023463e93c2fa36e", "question_id": 24935, "keypoint": [631, 39]} +{"image_id": "6b1bad9e14c54eff", "question_id": 13167, "keypoint": [242, 498]} +{"image_id": "14297a09cd0704df", "question_id": 2226, "keypoint": [444, 480]} +{"image_id": "9f9e1b7cdf8cc9d5", "question_id": 25407, "keypoint": [401, 49]} +{"image_id": "0012836b7cb7e1a8", "question_id": 29531, "keypoint": [822, 462]} +{"image_id": "2e975f0231b49855", "question_id": 28647, "keypoint": [522, 286]} +{"image_id": "ac32d325dbe4cf76", "question_id": 17033, "keypoint": [674, 28]} +{"image_id": "a6628a4b35d3382b", "question_id": 6722, "keypoint": [571, 426]} +{"image_id": "8c48baf8904eeca4", "question_id": 18305, "keypoint": [413, 333]} +{"image_id": "2124bb3f0620f8b0", "question_id": 11024, "keypoint": [578, 460]} +{"image_id": "5091e1f3b5a27a0f", "question_id": 14142, "keypoint": [392, 342]} +{"image_id": "eacf4d74fad54bac", "question_id": 31014, "keypoint": [883, 279]} +{"image_id": "3c570608f258bcfd", "question_id": 21503, "keypoint": [394, 793]} +{"image_id": "cb2aa50349915d64", "question_id": 23771, "keypoint": [325, 147]} +{"image_id": 
"92d76e2fdaf6f07d", "question_id": 3918, "keypoint": [639, 914]} +{"image_id": "06bc833b6794cb6d", "question_id": 18174, "keypoint": [399, 407]} +{"image_id": "0fc054810d8ab304", "question_id": 8424, "keypoint": [274, 37]} +{"image_id": "423165577f3f52ce", "question_id": 30439, "keypoint": [575, 348]} +{"image_id": "12bc08d2fce21a45", "question_id": 10451, "keypoint": [919, 286]} +{"image_id": "e101f9601e0d971c", "question_id": 667, "keypoint": [89, 416]} +{"image_id": "5894cb4708630e4c", "question_id": 27527, "keypoint": [202, 624]} +{"image_id": "df824b1152781395", "question_id": 20093, "keypoint": [862, 995]} +{"image_id": "fecd1e75b8deeca0", "question_id": 17756, "keypoint": [182, 183]} +{"image_id": "d81bf036569df61d", "question_id": 8999, "keypoint": [743, 125]} +{"image_id": "813f257ac2f2bfe2", "question_id": 25334, "keypoint": [141, 724]} +{"image_id": "dc9d9de141534563", "question_id": 27148, "keypoint": [494, 372]} +{"image_id": "56853543d7ae1235", "question_id": 13047, "keypoint": [362, 77]} +{"image_id": "0919842b32156e39", "question_id": 9127, "keypoint": [650, 264]} +{"image_id": "7112b28a96993e28", "question_id": 9994, "keypoint": [218, 677]} +{"image_id": "5ddd2a8a3c2328d6", "question_id": 19626, "keypoint": [693, 457]} +{"image_id": "765bb8ab1e6a0eff", "question_id": 15176, "keypoint": [536, 93]} +{"image_id": "4b7044e659d47df3", "question_id": 31928, "keypoint": [322, 107]} +{"image_id": "476198cff16798d6", "question_id": 25176, "keypoint": [460, 420]} +{"image_id": "e0b856b9fe929346", "question_id": 4906, "keypoint": [433, 561]} From 57b133750e2d2dc5c8cf2245f4f683270a682b79 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Sun, 4 May 2025 23:35:04 +0000 Subject: [PATCH 05/12] add keypoint demos to dataset --- .../datasets/text_vqa/text_vqa_tool_use_r1.py | 108 +++++++++++++++++- 1 file changed, 102 insertions(+), 6 deletions(-) diff --git a/src/r1_vlm/datasets/text_vqa/text_vqa_tool_use_r1.py b/src/r1_vlm/datasets/text_vqa/text_vqa_tool_use_r1.py index d1b2408..7bb27f3 100644 --- a/src/r1_vlm/datasets/text_vqa/text_vqa_tool_use_r1.py +++ b/src/r1_vlm/datasets/text_vqa/text_vqa_tool_use_r1.py @@ -1,4 +1,7 @@ -from datasets import Dataset, DatasetDict, load_dataset +import json + +from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset +from PIL import Image from tqdm import tqdm from r1_vlm.datasets.utils import IMAGE_PLACEHOLDER @@ -12,7 +15,8 @@ def resize_image(image): ( int(1024 * image.size[0] / longer_side), int(1024 * image.size[1] / longer_side), - ) + ), + Image.Resampling.LANCZOS, ) return image @@ -79,7 +83,10 @@ def generate_r1_messages(example): def create_r1_text_vqa_tool_use_dataset( - max_examples_per_split: int | None = None, splits_to_process: list[str] = None + max_examples_per_split: int | None = None, + splits_to_process: list[str] = None, + zoom_demos_path: str + | None = "/millcreek/home/sunil/r1_vlm_bumbershoot2/r1_vlm/src/r1_vlm/datasets/text_vqa/zoom_demos.jsonl", ): dataset = load_dataset("lmms-lab/textvqa") @@ -91,6 +98,29 @@ def create_r1_text_vqa_tool_use_dataset( if split not in valid_splits: raise ValueError(f"Invalid split: {split}") + # --- Load Zoom Demos if path provided --- + zoom_demos_lookup = None + if zoom_demos_path and "train" in splits_to_process: + print(f"Loading zoom demonstrations from: {zoom_demos_path}") + try: + with open(zoom_demos_path, "r") as f: + demos = [json.loads(line) for line in f] + zoom_demos_lookup = { + demo["question_id"]: demo["keypoint"] for demo in demos + } + print(f"Loaded 
{len(zoom_demos_lookup)} zoom demonstrations.") + except FileNotFoundError: + print( + f"Warning: Zoom demos file not found at {zoom_demos_path}. Skipping augmentation." + ) + zoom_demos_lookup = None + except Exception as e: + print( + f"Warning: Error loading zoom demos from {zoom_demos_path}: {e}. Skipping augmentation." + ) + zoom_demos_lookup = None + # --- End Load Zoom Demos --- + processed_datasets = {} for split in splits_to_process: print(f"Processing {split} split...") @@ -103,13 +133,79 @@ def create_r1_text_vqa_tool_use_dataset( if len(examples) >= max_examples_per_split: break - processed_datasets[split] = Dataset.from_list(examples) + processed_dataset = Dataset.from_list(examples) + + if split == "train" and zoom_demos_lookup is not None: + print(f"Augmenting '{split}' split with zoom keypoints...") + + def add_zoom_keypoint(example): + keypoint = zoom_demos_lookup.get(example["question_id"], None) + return {"zoom_keypoint": keypoint} + + # Add the column + processed_dataset = processed_dataset.map( + add_zoom_keypoint + ) # Added num_proc for potential speedup + print(f"Added 'zoom_keypoint' column to '{split}' split.") + + # --- REORDERING START --- + print( + f"Reordering '{split}' split to place examples with keypoints first..." + ) + # Filter examples with and without keypoints + with_keypoint_ds = processed_dataset.filter( + lambda example: example["zoom_keypoint"] is not None + ) + without_keypoint_ds = processed_dataset.filter( + lambda example: example["zoom_keypoint"] is None + ) + + # --- Shuffle the 'without_keypoint_ds' --- + print("Shuffling examples without keypoints...") + without_keypoint_ds = without_keypoint_ds.shuffle( + seed=42 + ) # Added shuffle with seed + print("Shuffled examples without keypoints.") + # --- End Shuffle --- + + # Concatenate them in the desired order + processed_dataset = concatenate_datasets( + [with_keypoint_ds, without_keypoint_ds] + ) + print(f"Reordered '{split}' split.") + # --- REORDERING END --- + + processed_datasets[split] = processed_dataset return DatasetDict(processed_datasets) if __name__ == "__main__": dataset = create_r1_text_vqa_tool_use_dataset( - max_examples_per_split=10, splits_to_process=["train"] + max_examples_per_split=10000, splits_to_process=["train"] + ) + + # count how many examples have zoom_keypoint + num_with_keypoint = len( + [ + example + for example in dataset["train"] + if example["zoom_keypoint"] is not None + ] ) - print(dataset["train"][0]) + print(f"Number of examples with zoom_keypoint: {num_with_keypoint}") + + # Verify the first few examples have keypoints (if any exist) + print("\nVerifying first few examples have keypoints (if available):") + for i in range(min(5, num_with_keypoint)): + print( + f"Example {i}: QID={dataset['train'][i]['question_id']}, Keypoint={dataset['train'][i]['zoom_keypoint']}" + ) + + # Verify an example after the keypoint block has None (if applicable) + if num_with_keypoint < len(dataset["train"]): + print("\nVerifying example after the keypoint block:") + idx_after = num_with_keypoint # First example expected to have None + print( + f"Example {idx_after}: QID={dataset['train'][idx_after]['question_id']}, Keypoint={dataset['train'][idx_after]['zoom_keypoint']}" + ) From 09a8ea2f42d5df02301fa305fed9db32d4dd21c7 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Sun, 4 May 2025 23:45:00 +0000 Subject: [PATCH 06/12] skeleton of new tool reward --- .../tool_use_text_vqa_env.py | 39 ++++++++++++++++++- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git 
a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py index 344947f..bb47c6b 100644 --- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py +++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py @@ -3,8 +3,6 @@ from datasets import Dataset from transformers import AutoProcessor -from trl.trainer.grpo_trainer import RewardFunc -from verifiers.parsers import XMLParser from r1_vlm.datasets.text_vqa.text_vqa_tool_use_r1 import ( create_r1_text_vqa_tool_use_dataset, @@ -14,6 +12,8 @@ from r1_vlm.environments.tool_vision_env import ToolArgParser, ToolVisionEnv from r1_vlm.tools.tool_prompts import SINGLE_TOOL_PROMPT_TEMPLATE from r1_vlm.tools.zoom import parse_zoom_args, zoom +from trl.trainer.grpo_trainer import RewardFunc +from verifiers.parsers import XMLParser # Implementing the exact preprocessing steps that are used in text vqa evaluation # https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py#L11 @@ -341,6 +341,11 @@ def get_reward_weights(self) -> list[float]: # consistent high reward for getting the answer right schedule = 1.0 reward_weights.append(schedule) + + elif reward_function.__name__ == "zoom_keypoint_reward_func": + # consistent high reward for using the zoom keypoint + schedule = 1.0 + reward_weights.append(schedule) else: raise ValueError( f"Unknown reward function: {reward_function.__name__} encountered in get_reward_weights" @@ -548,10 +553,40 @@ def correct_answer_reward_func( return correctness_rewards + def zoom_keypoint_reward_func( + prompts, completions, completions_messages, **kwargs + ): + merged_completion_conversations = MultistepVisionEnv.preprocess_messages( + prompts_messages=prompts, completions_messages=completions_messages + ) + + print(f"{prompts[0]=}") + print(f"{completions[0]=}") + print(f"{completions_messages[0]=}") + + # check if the example has a zoom keypoint + zoom_keypoint = kwargs.get("zoom_keypoint", None) + if zoom_keypoint is None: + return [0.0] * len(merged_completion_conversations) + + print(f"zoom_keypoint: {zoom_keypoint}") + + # determine if each completion uses the zoom keypoint: + # 1. parse each assistant message for a tool call + # 2. check if the tool call is a zoom call: "name: zoom" + # 3. check for keypoint: [x, y] in the tool call, and extract it + # 4. If there is > 1 zoom tool call with a keypoint in a conversation, we use the first one + # 5. Convert both keypoints to normalized coordinates (0-1) + # 5. compute the normalized euclidean distance between the keypoint in the tool call and the zoom keypoint + # 6. 
Reward = 1 - distance + + return [0.0] * len(merged_completion_conversations) + return [ format_reward_func, tool_execution_reward_func, correct_answer_reward_func, + zoom_keypoint_reward_func, ] def log_metrics(self, conversations, completions_text, completion_messages): From 35bb8943ccf2d3506718057ad02f2c7c97caf4b7 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Sun, 4 May 2025 17:58:23 -0700 Subject: [PATCH 07/12] ready to try training on examples with zoom demos --- .../tool_use_text_vqa_env.py | 103 ++++++++++++++---- .../tool_use_text_vqa_train.py | 8 +- 2 files changed, 89 insertions(+), 22 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py index bb47c6b..8b0a541 100644 --- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py +++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py @@ -1,8 +1,11 @@ +import math import re from typing import Any, Callable from datasets import Dataset from transformers import AutoProcessor +from trl.trainer.grpo_trainer import RewardFunc +from verifiers.parsers import XMLParser from r1_vlm.datasets.text_vqa.text_vqa_tool_use_r1 import ( create_r1_text_vqa_tool_use_dataset, @@ -12,8 +15,6 @@ from r1_vlm.environments.tool_vision_env import ToolArgParser, ToolVisionEnv from r1_vlm.tools.tool_prompts import SINGLE_TOOL_PROMPT_TEMPLATE from r1_vlm.tools.zoom import parse_zoom_args, zoom -from trl.trainer.grpo_trainer import RewardFunc -from verifiers.parsers import XMLParser # Implementing the exact preprocessing steps that are used in text vqa evaluation # https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py#L11 @@ -292,8 +293,13 @@ def get_dataset( output_datasets[split] = dataset_split + # only keep the examples with keypoints from the train set for testing if "train" in splits: - output_datasets["train"].shuffle() + print("Filtering train set to only include examples with zoom keypoints.") + output_datasets["train"] = output_datasets["train"].filter( + lambda x: x["zoom_keypoint"] is not None + ) + print(f"After filtering, {len(output_datasets['train'])} examples remain.") return output_datasets @@ -553,6 +559,17 @@ def correct_answer_reward_func( return correctness_rewards + def extract_keypoint(text): + # Pattern to match [x, y] where x and y are integers, allowing multiline + pattern = r"\[\s*(\d+)\s*,\s*(\d+)\s*\]" + match = re.search(pattern, text, re.DOTALL) + + if match: + x = int(match.group(1)) + y = int(match.group(2)) + return x, y + return None + def zoom_keypoint_reward_func( prompts, completions, completions_messages, **kwargs ): @@ -560,27 +577,75 @@ def zoom_keypoint_reward_func( prompts_messages=prompts, completions_messages=completions_messages ) - print(f"{prompts[0]=}") - print(f"{completions[0]=}") - print(f"{completions_messages[0]=}") + # extract the image size from the prompt + prompt = prompts[0] + image_content = prompt[1]["content"][1] + image = image_content["image"] + image_width = image.width + image_height = image.height - # check if the example has a zoom keypoint - zoom_keypoint = kwargs.get("zoom_keypoint", None) - if zoom_keypoint is None: + # extract the zoom keypoint from the completion, if it exists + conv_keypoints = [] + for conv in merged_completion_conversations: + keypoint_found = False + for msg in conv: + if msg["role"] == "assistant": + parsed = self.parser.parse(msg["content"][0]["text"]) + if 
hasattr(parsed, "tool") and parsed.tool is not None: + tool_content = parsed.tool + if "name: zoom" in tool_content: + conv_keypoint = extract_keypoint(tool_content) + conv_keypoints.append(conv_keypoint) + keypoint_found = True + break + # if no tool call with a zoom keypoint is found, append None + if not keypoint_found: + conv_keypoints.append(None) + + print(f"conv_keypoints: {conv_keypoints}") + + # check if the example has a zoom keypoint - this is a list of Group size tuples, each tuple is identical + zoom_keypoints = kwargs.get("zoom_keypoint", None) + + # if this example does not have a zoom keypoint, return 0 reward for all completions as there's no way to verify the result + if zoom_keypoints is None: return [0.0] * len(merged_completion_conversations) - print(f"zoom_keypoint: {zoom_keypoint}") + zoom_keypoint = zoom_keypoints[0] + normalized_zoom_keypoint = ( + zoom_keypoint[0] / image_width, + zoom_keypoint[1] / image_height, + ) + print(f"normalized_zoom_keypoint: {normalized_zoom_keypoint}") + + normalized_conv_keypoints = [] + for conv_keypoint in conv_keypoints: + if conv_keypoint is None: + normalized_conv_keypoints.append(None) + else: + normalized_conv_keypoints.append( + ( + conv_keypoint[0] / image_width, + conv_keypoint[1] / image_height, + ) + ) - # determine if each completion uses the zoom keypoint: - # 1. parse each assistant message for a tool call - # 2. check if the tool call is a zoom call: "name: zoom" - # 3. check for keypoint: [x, y] in the tool call, and extract it - # 4. If there is > 1 zoom tool call with a keypoint in a conversation, we use the first one - # 5. Convert both keypoints to normalized coordinates (0-1) - # 5. compute the normalized euclidean distance between the keypoint in the tool call and the zoom keypoint - # 6. 
Reward = 1 - distance + rewards = [] + for normalized_conv_keypoint in normalized_conv_keypoints: + if normalized_conv_keypoint is None: + rewards.append(0.0) + else: + # compute the normalized euclidean distance between the keypoint in the tool call and the zoom keypoint + distance = math.sqrt( + (normalized_conv_keypoint[0] - normalized_zoom_keypoint[0]) ** 2 + + (normalized_conv_keypoint[1] - normalized_zoom_keypoint[1]) + ** 2 + ) + # clip the reward at 0.0, as this value could be negative (1 - root 2) worst case + reward = max(1 - distance, 0.0) + rewards.append(reward) - return [0.0] * len(merged_completion_conversations) + return rewards return [ format_reward_func, diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py index acd678d..38761f6 100644 --- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py +++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py @@ -100,7 +100,7 @@ def train(): print("loaded env") # TODO: increase max examples per split - datasets = vf_env.get_dataset(splits=["train"]) + datasets = vf_env.get_dataset(splits=["train"], max_examples_per_split=100) train_dataset = datasets["train"] rubric = vf_env.get_rubric() @@ -110,7 +110,7 @@ def train(): training_args = GRPOConfig( model_init_kwargs=model_config, # save path on the runpod instance - output_dir="vlm-r1-text-vqa-tool-use-no-tool-reward-may4-3B", + output_dir="vlm-r1-text-vqa-distance-reward-may4-3B", # increase learning rate for PEFT - 1e-4 learning_rate=1e-4 if peft_config is not None else 1e-6, max_grad_norm=1.0, @@ -120,7 +120,7 @@ def train(): logging_steps=1, save_steps=50, save_total_limit=10, - num_train_epochs=1, + num_train_epochs=100, per_device_train_batch_size=1, num_generations=3, gradient_accumulation_steps=4, @@ -158,6 +158,8 @@ def train(): env=vf_env, peft_config=peft_config, guided_regex=None, + # shuffling is disabled because we want to start with the keypoint examples. + shuffle_dataset=False, ) trainer.train() From 2b22c81a06e453d6890a222de6d4e3dbe8bb15d6 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Sun, 4 May 2025 18:02:46 -0700 Subject: [PATCH 08/12] full trainset --- .../tool_use_text_vqa_env/tool_use_text_vqa_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py index 38761f6..3fad0e2 100644 --- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py +++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py @@ -100,7 +100,7 @@ def train(): print("loaded env") # TODO: increase max examples per split - datasets = vf_env.get_dataset(splits=["train"], max_examples_per_split=100) + datasets = vf_env.get_dataset(splits=["train"]) train_dataset = datasets["train"] rubric = vf_env.get_rubric() From f4c26967c948cb7db76b7932bc5f14ef0bc8525d Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Sun, 4 May 2025 20:49:23 -0700 Subject: [PATCH 09/12] Enhance zoom functionality by adding aspect ratio support and updating related methods. Modify image cropping logic to accommodate different aspect ratios, and update tests and argument parsing to include aspect ratio mode. Adjust output directory naming for clarity. 
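For reference, the crop-window arithmetic introduced here keeps the crop area roughly constant across modes (a minimal sketch of the same math as calculate_crop_coordinates, using its default base_target_size of 300; the helper name below is illustrative, not part of the repo):

    import math

    def crop_dims(aspect_ratio_mode: str, base_target_size: int = 300) -> tuple[int, int]:
        # All three modes target roughly the same area (base_target_size ** 2).
        sqrt2 = math.sqrt(2)
        if aspect_ratio_mode == "square":
            w, h = base_target_size, base_target_size
        elif aspect_ratio_mode == "horizontal":  # 2:1, wider than tall
            w, h = round(base_target_size * sqrt2), round(base_target_size / sqrt2)
        elif aspect_ratio_mode == "vertical":  # 1:2, taller than wide
            w, h = round(base_target_size / sqrt2), round(base_target_size * sqrt2)
        else:
            raise ValueError("aspect_ratio_mode must be 'square', 'horizontal', or 'vertical'")
        return max(1, w), max(1, h)

    # crop_dims("square")     -> (300, 300)
    # crop_dims("horizontal") -> (424, 212)
    # crop_dims("vertical")   -> (212, 424)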
--- .../environments/multistep_vision_env.py | 4 +- .../tool_use_text_vqa_env.py | 18 +- .../tool_use_text_vqa_train.py | 2 +- src/r1_vlm/tools/zoom.py | 268 ++++++++++++------ 4 files changed, 204 insertions(+), 88 deletions(-) diff --git a/src/r1_vlm/environments/multistep_vision_env.py b/src/r1_vlm/environments/multistep_vision_env.py index 825cf7a..9cc6a72 100644 --- a/src/r1_vlm/environments/multistep_vision_env.py +++ b/src/r1_vlm/environments/multistep_vision_env.py @@ -218,7 +218,9 @@ def clean_messages_for_logging(messages): ) for image in images: try: - imgcat.imgcat(image) + image_copy = image.copy() + image_copy.convert("RGB") + imgcat.imgcat(image_copy) except Exception as e: print( f"Caught failed imgcat call for image. As this is just a debugging print, we will not raise an error. {e}" diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py index 8b0a541..09816b2 100644 --- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py +++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py @@ -660,10 +660,16 @@ def log_metrics(self, conversations, completions_text, completion_messages): completions_with_tool_use = 0 completions_with_zoom_use = 0 + use_horizontal_zoom = 0 + use_vertical_zoom = 0 + use_square_zoom = 0 for completion in completions_text: tool_use_regex = r"(.*?)" zoom_use_string = "name: zoom" + horizontal_zoom_use_string = "aspect_ratio_mode: horizontal" + vertical_zoom_use_string = "aspect_ratio_mode: vertical" + square_zoom_use_string = "aspect_ratio_mode: square" tool_matches = re.findall(tool_use_regex, completion, re.DOTALL) if tool_matches: completions_with_tool_use += 1 @@ -671,8 +677,15 @@ def log_metrics(self, conversations, completions_text, completion_messages): if zoom_use_string in tool_content: completions_with_zoom_use += 1 + if horizontal_zoom_use_string in tool_content: + use_horizontal_zoom += 1 + if vertical_zoom_use_string in tool_content: + use_vertical_zoom += 1 + if square_zoom_use_string in tool_content: + use_square_zoom += 1 + print( - f"There are {len(completions_text)} completions, {completions_with_tool_use} of which attempt to use a tool, {completions_with_zoom_use} of which attempt to use zoom" + f"There are {len(completions_text)} completions, \n {completions_with_tool_use} of which attempt to use a tool, \n {completions_with_zoom_use} of which attempt to use zoom. 
\n {use_horizontal_zoom} of which attempt to use horizontal zoom, \n {use_vertical_zoom} of which attempt to use vertical zoom, \n {use_square_zoom} of which attempt to use square zoom" ) num_completions = len(completions_text) @@ -682,6 +695,9 @@ def log_metrics(self, conversations, completions_text, completion_messages): return { "tool_use_proportion": tool_use_proportion, "zoom_use_proportion": zoom_use_proportion, + "use_horizontal_zoom": use_horizontal_zoom, + "use_vertical_zoom": use_vertical_zoom, + "use_square_zoom": use_square_zoom, } diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py index 3fad0e2..eafa0ce 100644 --- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py +++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py @@ -110,7 +110,7 @@ def train(): training_args = GRPOConfig( model_init_kwargs=model_config, # save path on the runpod instance - output_dir="vlm-r1-text-vqa-distance-reward-may4-3B", + output_dir="vlm-r1-text-vqa-distance-reward-zoom-aspect-ratio-may4-3B", # increase learning rate for PEFT - 1e-4 learning_rate=1e-4 if peft_config is not None else 1e-6, max_grad_norm=1.0, diff --git a/src/r1_vlm/tools/zoom.py b/src/r1_vlm/tools/zoom.py index 534481d..74809dd 100644 --- a/src/r1_vlm/tools/zoom.py +++ b/src/r1_vlm/tools/zoom.py @@ -1,5 +1,6 @@ # generic zoom tool import json +import math # Added import from typing import List, Tuple import pytest @@ -9,31 +10,33 @@ def calculate_crop_coordinates( - keypoint: List[int], image_size: Tuple[int, int], target_size: int = 300 + keypoint: List[int], + image_size: Tuple[int, int], + aspect_ratio_mode: str, + base_target_size: int = 300, ) -> Tuple[int, int, int, int]: """ - Calculates the crop coordinates for a square box centered around a keypoint. + Calculates the crop coordinates for a box centered around a keypoint, + with a specified aspect ratio. + The box aims for an area roughly equal to base_target_size * base_target_size. If the target box extends beyond the image boundaries, the box is shifted - to stay within the image, maintaining the target size. The keypoint will - no longer be centered in this case. - - The returned coordinates are suitable for use with PIL's Image.crop method, - following the (left, upper, right, lower) format where 'right' and 'lower' - are exclusive. + to stay within the image, maintaining the target size and aspect ratio. + The keypoint will no longer be centered in this case. Args: keypoint: A list or tuple [x, y] representing the desired center coordinates. image_size: A tuple (width, height) of the original image. - target_size: The desired width and height of the square crop box. + aspect_ratio_mode: The desired aspect ratio ("square", "horizontal", "vertical"). + base_target_size: The base dimension used for area calculation. Returns: A tuple (crop_left, crop_upper, crop_right, crop_lower) defining the region to crop from the original image. Raises: - ValueError: If the keypoint is outside the image boundaries or inputs - are invalid types/formats. + ValueError: If inputs are invalid types/formats, keypoint is outside + the image boundaries, or aspect_ratio_mode is invalid. 
""" # --- Input Validation --- if ( @@ -54,8 +57,13 @@ def calculate_crop_coordinates( "Error:image_size must be a tuple of two positive integers (width, height)" ) - if not isinstance(target_size, int) or target_size <= 0: - raise ValueError("Error: target_size must be a positive integer") + if not isinstance(base_target_size, int) or base_target_size <= 0: + raise ValueError("Error: base_target_size must be a positive integer") + + if aspect_ratio_mode not in ["square", "horizontal", "vertical"]: + raise ValueError( + "Error: aspect_ratio_mode must be 'square', 'horizontal', or 'vertical'" + ) x, y = keypoint width, height = image_size @@ -66,22 +74,40 @@ def calculate_crop_coordinates( f"(width={width}, height={height})" ) + # --- Calculate Target Crop Dimensions based on Aspect Ratio --- + sqrt2 = math.sqrt(2) + if aspect_ratio_mode == "square": + target_crop_width = base_target_size + target_crop_height = base_target_size + elif aspect_ratio_mode == "horizontal": # 2:1 + target_crop_width = round(base_target_size * sqrt2) + target_crop_height = round(base_target_size / sqrt2) + else: # aspect_ratio_mode == "vertical" (1:2) + target_crop_width = round(base_target_size / sqrt2) + target_crop_height = round(base_target_size * sqrt2) + + # Ensure dimensions are at least 1 + target_crop_width = max(1, target_crop_width) + target_crop_height = max(1, target_crop_height) + # --- Calculate Crop Box --- - half_size = target_size // 2 + # Calculate half-sizes (integer division) + half_width = target_crop_width // 2 + half_height = target_crop_height // 2 # Calculate the ideal top-left corner to center the box - ideal_x_min = x - half_size - ideal_y_min = y - half_size + ideal_x_min = x - half_width + ideal_y_min = y - half_height # Initialize the actual top-left corner crop_left = ideal_x_min crop_upper = ideal_y_min # Adjust if the box extends beyond the right or bottom boundaries - if crop_left + target_size > width: - crop_left = width - target_size - if crop_upper + target_size > height: - crop_upper = height - target_size + if crop_left + target_crop_width > width: + crop_left = width - target_crop_width + if crop_upper + target_crop_height > height: + crop_upper = height - target_crop_height # Adjust if the box extends beyond the left or top boundaries if crop_left < 0: @@ -91,13 +117,11 @@ def calculate_crop_coordinates( # Calculate the final right and lower bounds for PIL crop (exclusive) # Clamp the right/lower bounds to the image dimensions, important if - # target_size > image width/height. - crop_right = min(crop_left + target_size, width) - crop_lower = min(crop_upper + target_size, height) + # target crop dimensions > image width/height. + crop_right = min(crop_left + target_crop_width, width) + crop_lower = min(crop_upper + target_crop_height, height) # Ensure left/upper are also clamped in case target_size > image dimensions - # which might lead to negative values after boundary adjustments. - # The previous checks already handle this, but being explicit doesn't hurt. crop_left = max(0, crop_left) crop_upper = max(0, crop_upper) @@ -107,29 +131,48 @@ def calculate_crop_coordinates( def zoom( image_name: str, keypoint: list[int], + aspect_ratio_mode: str, **kwargs, ) -> Image.Image: """ - Returns an image zoomed in on the specified keypoint. + Returns an image zoomed in on the specified keypoint with a given aspect ratio. This is useful to see a region of interest with higher clarity, like for reading text or identifying objects. 
+ The choice of aspect ration is helpful depending on the task. + If you are looking at horizontal text or some other wide region, you might use "horizontal" aspect ratio. + If you are looking at vertical text or some other tall region, you might use "vertical" aspect ratio. + If you are looking at a region more generally, you might use "square" aspect ratio. + Args: image_name: str, the name of the image to zoom in on. Can only be called on the "input_image". keypoint: list[int], the [x, y] coordinates of the point to center the zoom on. Must be within the image boundaries. + aspect_ratio_mode: str, the desired aspect ratio for the zoom window. Must be one of "square", "horizontal", "vertical". + "square" is 1:1, "horizontal" is 2:1 (wider), "vertical" is 1:2 (taller). Returns: - The image zoomed in on the specified keypoint. + The image zoomed in on the specified keypoint with the requested aspect ratio. Examples: name: zoom image_name: input_image + aspect_ratio_mode: square keypoint: [500, 400] + + + + name: zoom + image_name: input_image + aspect_ratio_mode: horizontal + keypoint: [23, 44] + name: zoom image_name: input_image - keypoint: [50, 40] + aspect_ratio_mode: vertical + keypoint: [723, 461] + """ # get and validate the image @@ -146,27 +189,50 @@ def zoom( f"Error: Image {image_name} is not the input_image. This tool can only be called on the input_image." ) - width, height = image.size - # size we will crop to - target_size = 300 - - # we will resize the image to the larger dimension of the input image, or 400 if it's a smaller image - resize_size = max(width, height, 400) + original_width, original_height = image.size + # base size for cropping area calculation + base_crop_target_size = 300 - # Validate keypoint and calculate crop box using the helper function - # ValueError will be raised by the helper if keypoint is invalid - crop_box = calculate_crop_coordinates(keypoint, (width, height), target_size) + # Validate aspect_ratio_mode and keypoint, then calculate crop box + crop_box = calculate_crop_coordinates( + keypoint, + (original_width, original_height), + aspect_ratio_mode, + base_target_size=base_crop_target_size, + ) # crop the image to the calculated box cropped_image = image.crop(crop_box) - - # Resize the cropped image to the target size - output_image = cropped_image.resize( - (resize_size, resize_size), Image.Resampling.LANCZOS + crop_w, crop_h = cropped_image.size + + # If crop resulted in zero size (e.g., invalid crop box calculation edge case) + if crop_w == 0 or crop_h == 0: + raise ValueError("Error: Cropped image has zero dimension") + + # Determine the base dimension for resizing the output - use original image dims or 400 min + base_resize_dim = max(original_width, original_height, 400) + + # Calculate the final output size, preserving aspect ratio of the cropped image + # Scale so the LARGER dimension of the output matches base_resize_dim + if crop_w >= crop_h: # Wider or square crop + output_w = base_resize_dim + output_h = round(output_w * crop_h / crop_w) # Maintain aspect ratio + else: # Taller crop + output_h = base_resize_dim + output_w = round(output_h * crop_w / crop_h) # Maintain aspect ratio + + # Ensure minimum dimension is 1 + output_w = max(1, output_w) + output_h = max(1, output_h) + output_size = (output_w, output_h) + + # Resize the cropped image to the calculated output size, preserving aspect ratio + output_image = cropped_image.resize(output_size, Image.Resampling.LANCZOS) + + print( + f"Zoom tool cropped to {crop_w}x{crop_h}, 
resized output to: {output_image.size}" ) - print(f"Zoom tool output image size: {output_image.size}") - return output_image @@ -174,7 +240,7 @@ def parse_zoom_args(raw_args: RawToolArgs) -> TypedToolArgs: """ Parses raw string arguments for the zoom tool (keypoint version). - Expects keys: 'name', 'image_name', 'keypoint'. + Expects keys: 'name', 'image_name', 'keypoint', 'aspect_ratio_mode'. Converts 'keypoint' from a JSON string to a list of integers. Detailed validation of values (e.g., keypoint coordinates) is deferred to the zoom function itself. @@ -184,14 +250,20 @@ def parse_zoom_args(raw_args: RawToolArgs) -> TypedToolArgs: Returns: A dictionary containing the arguments with basic type conversions applied, - ready for the zoom function. Keys: 'image_name', 'keypoint'. + ready for the zoom function. Keys: 'image_name', 'keypoint', 'aspect_ratio_mode'. Raises: ValueError: If required keys are missing, extra keys are present, - or basic type conversion fails (e.g., 'keypoint' is not valid JSON - or doesn't result in a list of two integers). + basic type conversion fails (e.g., 'keypoint' is not valid JSON + or doesn't result in a list of two integers), or + 'aspect_ratio_mode' has an invalid value. """ - required_keys = {"name", "image_name", "keypoint"} + required_keys = { + "name", + "image_name", + "keypoint", + "aspect_ratio_mode", + } # Added aspect_ratio_mode actual_keys = set(raw_args.keys()) # 1. Check for Missing Keys @@ -224,9 +296,12 @@ def parse_zoom_args(raw_args: RawToolArgs) -> TypedToolArgs: raise ValueError( "Error: Both elements in 'keypoint' list must be integers." ) - typed_args["keypoint"] = keypoint_list + # Validate aspect_ratio_mode + aspect_mode = raw_args["aspect_ratio_mode"] + typed_args["aspect_ratio_mode"] = aspect_mode + except json.JSONDecodeError: raise ValueError( f"Error: Invalid JSON format for 'keypoint': '{raw_args['keypoint']}'" @@ -254,7 +329,7 @@ def test_calc_centered(sample_image): img_size = sample_image["input_image"].size kp = [500, 400] ts = 300 - box = calculate_crop_coordinates(kp, img_size, ts) + box = calculate_crop_coordinates(kp, img_size, "square", ts) assert box == (350, 250, 650, 550) @@ -262,7 +337,7 @@ def test_calc_top_left(sample_image): img_size = sample_image["input_image"].size kp = [50, 40] ts = 300 - box = calculate_crop_coordinates(kp, img_size, ts) + box = calculate_crop_coordinates(kp, img_size, "square", ts) assert box == (0, 0, 300, 300) @@ -270,7 +345,7 @@ def test_calc_bottom_right(sample_image): img_size = sample_image["input_image"].size kp = [950, 750] ts = 300 - box = calculate_crop_coordinates(kp, img_size, ts) + box = calculate_crop_coordinates(kp, img_size, "square", ts) assert box == (700, 500, 1000, 800) @@ -278,7 +353,7 @@ def test_calc_right_middle(sample_image): img_size = sample_image["input_image"].size kp = [950, 400] ts = 300 - box = calculate_crop_coordinates(kp, img_size, ts) + box = calculate_crop_coordinates(kp, img_size, "square", ts) assert box == (700, 250, 1000, 550) @@ -286,7 +361,7 @@ def test_calc_small_image(sample_image): small_img_size = (200, 150) kp = [100, 75] ts = 300 - box = calculate_crop_coordinates(kp, small_img_size, ts) + box = calculate_crop_coordinates(kp, small_img_size, "square", ts) assert box == (0, 0, 200, 150) @@ -294,30 +369,30 @@ def test_calc_invalid_keypoint_coords(sample_image): img_size = sample_image["input_image"].size ts = 300 with pytest.raises(ValueError, match="outside the image boundaries"): - calculate_crop_coordinates([1000, 400], img_size, ts) + 
calculate_crop_coordinates([1000, 400], img_size, "square", ts) with pytest.raises(ValueError, match="outside the image boundaries"): - calculate_crop_coordinates([-1, 400], img_size, ts) + calculate_crop_coordinates([-1, 400], img_size, "square", ts) with pytest.raises(ValueError, match="outside the image boundaries"): - calculate_crop_coordinates([500, 800], img_size, ts) + calculate_crop_coordinates([500, 800], img_size, "square", ts) with pytest.raises(ValueError, match="outside the image boundaries"): - calculate_crop_coordinates([500, -1], img_size, ts) + calculate_crop_coordinates([500, -1], img_size, "square", ts) def test_calc_invalid_input_types(): with pytest.raises(ValueError, match="keypoint must be a list or tuple"): - calculate_crop_coordinates("[100, 100]", (500, 500), 300) + calculate_crop_coordinates("[100, 100]", (500, 500), "square", 300) with pytest.raises(ValueError, match="keypoint must be a list or tuple"): - calculate_crop_coordinates([100], (500, 500), 300) + calculate_crop_coordinates([100], (500, 500), "square", 300) with pytest.raises(ValueError, match="keypoint must be a list or tuple"): - calculate_crop_coordinates([100.0, 100], (500, 500), 300) + calculate_crop_coordinates([100.0, 100], (500, 500), "square", 300) with pytest.raises(ValueError, match="image_size must be a tuple"): - calculate_crop_coordinates([100, 100], [500, 500], 300) + calculate_crop_coordinates([100, 100], [500, 500], "square", 300) with pytest.raises(ValueError, match="image_size must be a tuple"): - calculate_crop_coordinates([100, 100], (500, 0), 300) + calculate_crop_coordinates([100, 100], (500, 0), "square", 300) with pytest.raises(ValueError, match="target_size must be a positive integer"): - calculate_crop_coordinates([100, 100], (500, 500), 0) + calculate_crop_coordinates([100, 100], (500, 500), "square", 0) with pytest.raises(ValueError, match="target_size must be a positive integer"): - calculate_crop_coordinates([100, 100], (500, 500), -100) + calculate_crop_coordinates([100, 100], (500, 500), "square", -100) # --- Tests for zoom function --- @@ -325,7 +400,7 @@ def test_calc_invalid_input_types(): def test_zoom_invalid_image_name(sample_image): with pytest.raises(ValueError, match="not found"): - zoom("nonexistent_image", [100, 100], images=sample_image) + zoom("nonexistent_image", [100, 100], "square", images=sample_image) # Add test for incorrect image name usage (e.g., "test_image" instead of "input_image") @@ -334,6 +409,7 @@ def test_zoom_incorrect_image_name_usage(sample_image): zoom( "test_image", [100, 100], + "square", images={"test_image": sample_image["input_image"]}, ) @@ -342,44 +418,50 @@ def test_zoom_incorrect_image_name_usage(sample_image): def test_zoom_invalid_keypoint_format(sample_image): with pytest.raises(ValueError, match="keypoint must be a list or tuple"): zoom( - "input_image", "[100, 100]", images=sample_image + "input_image", "[100, 100]", "square", images=sample_image ) # Pass string instead of list with pytest.raises(ValueError, match="keypoint must be a list or tuple"): - zoom("input_image", [100], images=sample_image) # Pass list with 1 element + zoom( + "input_image", [100], "square", images=sample_image + ) # Pass list with 1 element with pytest.raises(ValueError, match="keypoint must be a list or tuple"): - zoom("input_image", [100.0, 100], images=sample_image) # Pass float + zoom("input_image", [100.0, 100], "square", images=sample_image) # Pass float # Test keypoint outside image bounds (should be caught by calculate_crop_coordinates) 
def test_zoom_keypoint_out_of_bounds(sample_image): img_size = sample_image["input_image"].size with pytest.raises(ValueError, match="outside the image boundaries"): - zoom("input_image", [img_size[0], 100], images=sample_image) # x = width + zoom( + "input_image", [img_size[0], 100], "square", images=sample_image + ) # x = width with pytest.raises(ValueError, match="outside the image boundaries"): - zoom("input_image", [-1, 100], images=sample_image) # x < 0 + zoom("input_image", [-1, 100], "square", images=sample_image) # x < 0 with pytest.raises(ValueError, match="outside the image boundaries"): - zoom("input_image", [100, img_size[1]], images=sample_image) # y = height + zoom( + "input_image", [100, img_size[1]], "square", images=sample_image + ) # y = height with pytest.raises(ValueError, match="outside the image boundaries"): - zoom("input_image", [100, -1], images=sample_image) # y < 0 + zoom("input_image", [100, -1], "square", images=sample_image) # y < 0 def test_zoom_basic_centered(sample_image): keypoint = [500, 400] # Center of 1000x800 image - result = zoom("input_image", keypoint, images=sample_image) + result = zoom("input_image", keypoint, "square", images=sample_image) assert isinstance(result, Image.Image) assert result.size == (300, 300) # Always 300x300 output def test_zoom_edge_case_top_left(sample_image): keypoint = [50, 40] # Near top-left - result = zoom("input_image", keypoint, images=sample_image) + result = zoom("input_image", keypoint, "square", images=sample_image) assert isinstance(result, Image.Image) assert result.size == (300, 300) def test_zoom_edge_case_bottom_right(sample_image): keypoint = [950, 750] # Near bottom-right - result = zoom("input_image", keypoint, images=sample_image) + result = zoom("input_image", keypoint, "square", images=sample_image) assert isinstance(result, Image.Image) assert result.size == (300, 300) @@ -389,7 +471,7 @@ def test_zoom_small_image(sample_image): small_img = Image.new("RGB", (200, 150)) small_sample = {"input_image": small_img} keypoint = [100, 75] # Center of small image - result = zoom("input_image", keypoint, images=small_sample) + result = zoom("input_image", keypoint, "square", images=small_sample) assert isinstance(result, Image.Image) assert result.size == (300, 300) # Output is still 300x300 after resize @@ -404,7 +486,7 @@ def test_zoom_content_check(sample_image): filled_sample = {"input_image": img} keypoint = [500, 400] # Center point within the red box - result = zoom("input_image", keypoint, images=filled_sample) + result = zoom("input_image", keypoint, "square", images=filled_sample) assert result.size == (300, 300) # Check a central pixel in the output (should correspond to the red box) # The original keypoint [500, 400] should map to the center [150, 150] in the 300x300 output @@ -416,14 +498,25 @@ def test_zoom_content_check(sample_image): def test_parse_valid_args(): - raw = {"name": "zoom", "image_name": "input_image", "keypoint": "[150, 250]"} - expected = {"image_name": "input_image", "keypoint": [150, 250]} + raw = { + "name": "zoom", + "image_name": "input_image", + "keypoint": "[150, 250]", + "aspect_ratio_mode": "square", + } + expected = { + "image_name": "input_image", + "keypoint": [150, 250], + "aspect_ratio_mode": "square", + } assert parse_zoom_args(raw) == expected def test_parse_missing_key(): - raw = {"name": "zoom", "image_name": "input_image"} - with pytest.raises(ValueError, match="Missing required arguments.*keypoint"): + raw = {"name": "zoom", "image_name": "input_image", 
"keypoint": "[150, 250]"} + with pytest.raises( + ValueError, match="Missing required arguments.*aspect_ratio_mode" + ): parse_zoom_args(raw) @@ -432,6 +525,7 @@ def test_parse_extra_key(): "name": "zoom", "image_name": "input_image", "keypoint": "[100, 100]", + "aspect_ratio_mode": "square", "extra": "bad", } with pytest.raises(ValueError, match="Unexpected arguments.*extra"): @@ -443,6 +537,7 @@ def test_parse_invalid_json(): "name": "zoom", "image_name": "input_image", "keypoint": "[100, 100", + "aspect_ratio_mode": "square", } # Missing closing bracket with pytest.raises(ValueError, match="Invalid JSON format for 'keypoint'"): parse_zoom_args(raw) @@ -453,6 +548,7 @@ def test_parse_keypoint_not_list(): "name": "zoom", "image_name": "input_image", "keypoint": '{"x": 100, "y": 100}', + "aspect_ratio_mode": "square", } with pytest.raises(ValueError, match="'keypoint' must be a JSON list"): parse_zoom_args(raw) @@ -503,18 +599,18 @@ def test_parse_keypoint_wrong_type(): try: print(f"\nTesting center keypoint: {kp_center}") - result_center = zoom("input_image", kp_center, images=images_dict) + result_center = zoom("input_image", kp_center, "square", images=images_dict) print(f"Center zoom output size: {result_center.size}") # result_center.show() # Optionally display the image print(f"\nTesting edge keypoint: {kp_edge}") - result_edge = zoom("input_image", kp_edge, images=images_dict) + result_edge = zoom("input_image", kp_edge, "square", images=images_dict) print(f"Edge zoom output size: {result_edge.size}") # result_edge.show() # Optionally display the image print("\nTesting invalid keypoint:") try: - zoom("input_image", [1200, 100], images=images_dict) + zoom("input_image", [1200, 100], "square", images=images_dict) except ValueError as e: print(f"Caught expected error: {e}") @@ -523,6 +619,7 @@ def test_parse_keypoint_wrong_type(): "name": "zoom", "image_name": "input_image", "keypoint": "[150, 250]", + "aspect_ratio_mode": "square", } parsed = parse_zoom_args(raw_valid) print(f"Parsed valid args: {parsed}") @@ -531,6 +628,7 @@ def test_parse_keypoint_wrong_type(): "name": "zoom", "image_name": "input_image", "keypoint": "[150.5, 250]", + "aspect_ratio_mode": "square", } try: parse_zoom_args(raw_invalid) From bbd9ea13ae5c70832713ecd7baf81391e18f2a9c Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Mon, 5 May 2025 07:44:19 -0700 Subject: [PATCH 10/12] ready for stage 2 of ft on the entire train set --- .../tool_use_text_vqa_env.py | 16 ++++++++++------ .../tool_use_text_vqa_train.py | 8 ++++++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py index 09816b2..198849c 100644 --- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py +++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py @@ -293,13 +293,17 @@ def get_dataset( output_datasets[split] = dataset_split - # only keep the examples with keypoints from the train set for testing + # only keep the examples with keypoints from the train set for initial finetuning + # if "train" in splits: + # print("Filtering train set to only include examples with zoom keypoints.") + # output_datasets["train"] = output_datasets["train"].filter( + # lambda x: x["zoom_keypoint"] is not None + # ) + # print(f"After filtering, {len(output_datasets['train'])} examples remain.") + + # shuffle the train set if "train" in splits: - print("Filtering train set to 
only include examples with zoom keypoints.") - output_datasets["train"] = output_datasets["train"].filter( - lambda x: x["zoom_keypoint"] is not None - ) - print(f"After filtering, {len(output_datasets['train'])} examples remain.") + output_datasets["train"].shuffle() return output_datasets diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py index eafa0ce..51bb215 100644 --- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py +++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_train.py @@ -90,8 +90,12 @@ def find_target_linear_names( def train(): + checkpoint = "/millcreek/home/sunil/r1_vlm/vlm-r1-text-vqa-distance-reward-zoom-aspect-ratio-may4-3B/checkpoint-500" + model, peft_config, processor, model_config, gradient_checkpointing = ( - load_model_and_processor(gradient_checkpointing=True, use_peft=False) + load_model_and_processor( + model_name_or_path=checkpoint, gradient_checkpointing=True, use_peft=False + ) ) print("loaded model") @@ -110,7 +114,7 @@ def train(): training_args = GRPOConfig( model_init_kwargs=model_config, # save path on the runpod instance - output_dir="vlm-r1-text-vqa-distance-reward-zoom-aspect-ratio-may4-3B", + output_dir="vlm-r1-text-vqa-post-demo-ft-may5-3B", # increase learning rate for PEFT - 1e-4 learning_rate=1e-4 if peft_config is not None else 1e-6, max_grad_norm=1.0, From 293c1ff85777d5415b41e9b867155fda817131ca Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Mon, 5 May 2025 08:31:59 -0700 Subject: [PATCH 11/12] fix reward for the case where no gt is present. --- .../tool_use_text_vqa_env/tool_use_text_vqa_env.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py index 198849c..873da8f 100644 --- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py +++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py @@ -616,6 +616,11 @@ def zoom_keypoint_reward_func( return [0.0] * len(merged_completion_conversations) zoom_keypoint = zoom_keypoints[0] + + # if the zoom keypoint is not present, return 0 reward for all completions as there's no way to verify the result + if zoom_keypoint is None: + return [0.0] * len(merged_completion_conversations) + normalized_zoom_keypoint = ( zoom_keypoint[0] / image_width, zoom_keypoint[1] / image_height, From ac0061ac6a43ec461b606acd5e027a4472292b27 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Mon, 5 May 2025 08:32:17 -0700 Subject: [PATCH 12/12] shuffle properly --- .../environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py index 873da8f..d041d94 100644 --- a/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py +++ b/src/r1_vlm/environments/tool_use_text_vqa_env/tool_use_text_vqa_env.py @@ -303,7 +303,7 @@ def get_dataset( # shuffle the train set if "train" in splits: - output_datasets["train"].shuffle() + output_datasets["train"] = output_datasets["train"].shuffle() return output_datasets
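
The test updates earlier in this series thread a new required aspect_ratio_mode argument ("square" in every call) through calculate_crop_coordinates, zoom, and parse_zoom_args. The sketch below shows how a caller would exercise the updated signature; it is a hypothetical usage example rather than code from the repo. The import path is an assumption, while the "input_image" key, the "square" mode, and the 300x300 output size are taken directly from the tests.

# Hypothetical usage sketch; import path is assumed, behavior follows the tests above.
from PIL import Image

from r1_vlm.tools.zoom import parse_zoom_args, zoom  # assumed module path

images = {"input_image": Image.new("RGB", (1000, 800))}

# aspect_ratio_mode is now a required positional argument between keypoint and images.
crop = zoom("input_image", [500, 400], "square", images=images)
assert crop.size == (300, 300)

# Raw tool-call arguments must now carry aspect_ratio_mode as well, otherwise
# parse_zoom_args raises "Missing required arguments".
raw = {
    "name": "zoom",
    "image_name": "input_image",
    "keypoint": "[500, 400]",
    "aspect_ratio_mode": "square",
}
parsed = parse_zoom_args(raw)
crop = zoom(parsed["image_name"], parsed["keypoint"], parsed["aspect_ratio_mode"], images=images)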
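
Patch 11/12 guards the keypoint reward against examples that have no ground-truth zoom keypoint: with nothing to verify against, every completion receives 0.0. The sketch below is an illustrative, self-contained version of that gating pattern, assuming a reward of the form 1 minus normalized distance; the function name and signature are invented for the example, and the real zoom_keypoint_reward_func operates on merged completion conversations rather than bare keypoints.

import math


def keypoint_distance_reward(predicted_keypoints, gt_keypoint, image_size):
    """Illustrative reward: 1 - (normalized distance to the ground-truth keypoint).

    Mirrors the guard added in patch 11: when no ground truth exists there is
    nothing to verify, so every completion receives 0.0.
    """
    if gt_keypoint is None:
        return [0.0] * len(predicted_keypoints)

    width, height = image_size
    gt = (gt_keypoint[0] / width, gt_keypoint[1] / height)

    rewards = []
    for kp in predicted_keypoints:
        if kp is None:  # completion never produced a zoom keypoint
            rewards.append(0.0)
            continue
        pred = (kp[0] / width, kp[1] / height)
        # Normalized coordinates live in the unit square, so sqrt(2) bounds the distance.
        rewards.append(max(0.0, 1.0 - math.dist(pred, gt) / math.sqrt(2)))
    return rewards


# No ground truth -> all-zero rewards, one per completion.
assert keypoint_distance_reward([[10, 20], None, [500, 400]], None, (1000, 800)) == [0.0, 0.0, 0.0]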
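
Patch 12/12 ("shuffle properly") exists because datasets.Dataset.shuffle() is not in-place: it returns a new, shuffled dataset, so the earlier bare call output_datasets["train"].shuffle() discarded its result. A minimal sketch of the difference, using a toy placeholder dataset:

from datasets import Dataset

ds = Dataset.from_dict({"idx": list(range(10))})

ds.shuffle(seed=0)             # no effect on ds: the shuffled copy is discarded
shuffled = ds.shuffle(seed=0)  # keep the returned dataset, as in the fix

print(ds["idx"])        # still [0, 1, ..., 9]
print(shuffled["idx"])  # a permutation of 0..9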