Changes from all commits
25 commits
32054fb
revised object detection tool to use triton + yolo
sunildkumar Apr 29, 2025
610aeb9
Re-introduce object detection tool, but this time it is a YOLO instea…
sunildkumar Apr 29, 2025
af03c8c
ready to train
sunildkumar Apr 30, 2025
76542f5
start server was only local
sunildkumar Apr 30, 2025
2d6b792
working now
sunildkumar Apr 30, 2025
41dceb6
fix threading issues and bgr to rgb
sunildkumar Apr 30, 2025
2bedbac
up the weight to 1.0.
sunildkumar Apr 30, 2025
12fec83
setup to restart the run
sunildkumar Apr 30, 2025
0f8791f
the code technically works now, but it isn't pretty
sunildkumar Apr 30, 2025
950950c
this works but it is stupid slow, trying to move call out of training…
sunildkumar Apr 30, 2025
4d93020
it's working! and it's fast
sunildkumar Apr 30, 2025
f28e31c
remove fork thing
sunildkumar Apr 30, 2025
c471d8b
shuffle order of tools in system prompt
sunildkumar Apr 30, 2025
a74cd68
log metrics callback
sunildkumar Apr 30, 2025
c2c702f
reset the schedule
sunildkumar Apr 30, 2025
77df0b6
ready to start training again
sunildkumar Apr 30, 2025
56f804b
more robust way of catching
sunildkumar Apr 30, 2025
be43b07
generalized the eval script
sunildkumar May 1, 2025
bd99398
better eval script
sunildkumar May 1, 2025
9eab93e
implement the new combined correctness-and-tool-use reward
ROIM1998 May 1, 2025
1cbdd10
always return all 4 rewards but setting the schedules differently bas…
ROIM1998 May 2, 2025
0ed8429
add num_generations check to make sure the new reward sees the entire…
ROIM1998 May 2, 2025
35cd0a4
Merge pull request #67 from groundlight/aok_refactor_combined_reward
ROIM1998 May 2, 2025
39f6e9e
run name
sunildkumar May 2, 2025
d9ddbcd
try adding a short term incentive to use tools with new fancy reward
sunildkumar May 2, 2025
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -48,6 +48,8 @@ dependencies = [
"tiktoken>=0.9.0",
"openai>=1.65.4",
"opencv-python>=4.11.0.86",
"tritonclient[all]>=2.51.0",
"ultralytics>=8.3.120",
]

[tool.hatch.metadata]
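The two new dependencies back the revised object detection tool noted in the commit history (a YOLO model served through Triton). A minimal sketch of how such a tool might call the server through ultralytics' Triton support — the server URL, model endpoint, and helper function below are illustrative assumptions, not code from this PR:

```python
# Hypothetical sketch only: the endpoint name, URL, and return format are assumptions.
from ultralytics import YOLO


def detect_objects(image_path: str) -> list[dict]:
    # ultralytics accepts a Triton Inference Server URL in place of a weights
    # file; the task must then be given explicitly.
    model = YOLO("http://localhost:8000/yolo", task="detect")
    results = model(image_path)

    detections = []
    for result in results:
        for box in result.boxes:
            detections.append(
                {
                    "class_id": int(box.cls),  # label-name mapping depends on the deployment
                    "confidence": float(box.conf),
                    "bbox": [float(v) for v in box.xyxy[0]],  # [x1, y1, x2, y2] in pixels
                }
            )
    return detections
```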
10 changes: 5 additions & 5 deletions src/r1_vlm/datasets/aok_vqa/aok_vqa_mc_tool_use_r1.py
@@ -28,19 +28,19 @@ def generate_r1_messages(example):

system_prompt = "REPLACED WITH TOOLS SYSTEM PROMPT"

choices_str = "These are the possible answers, you must choose one: "
choices_str = "Possible answers: "
for i, choice in enumerate(choices):
if i == len(choices) - 1:
choices_str += f"or {choice}."
else:
choices_str += f"{choice}, "

instruction = f"""
{question}
Question: {question}

{choices_str}
{choices_str} You must choose one to answer the question and place in <answer> tags.

You must inspect the input image and gather visual evidence. The image size is {image_size}.
You must inspect the input image to gather visual evidence. After you've collected evidence, combine that with your knowledge of the world to answer the question. You must consider all 4 possible answers when thinking through your reasoning. The image size is {image_size}.
"""

r1_messages = [
@@ -66,7 +66,7 @@ def generate_r1_messages(example):
"content": [
{
"type": "text",
"text": "\n<think> I'll collect as much visual evidence as possible from the image. First, I'll consider what region of the image to zoom in on to get the most information. Then, I'll review and consider the four possible answers. Then, I'll select the most likely answer based on the evidence and my knowledge of the world.",
"text": "\n<think> I'll collect as much visual evidence as possible from the image. Then, I'll consider the four possible answers. Finally, I'll select the most likely answer based on the evidence and my knowledge of the world.",
}
],
},
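For orientation, the rebuilt prompt loop above produces the following string for a hypothetical four-way multiple-choice example (the answer options are invented):

```python
choices = ["red", "blue", "green", "yellow"]  # hypothetical choices

choices_str = "Possible answers: "
for i, choice in enumerate(choices):
    if i == len(choices) - 1:
        choices_str += f"or {choice}."
    else:
        choices_str += f"{choice}, "

print(choices_str)
# Possible answers: red, blue, green, or yellow.
```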
87 changes: 54 additions & 33 deletions src/r1_vlm/environments/multistep_vision_env.py
@@ -57,7 +57,7 @@ def is_completed(self, messages: list[dict[str, str]], **kwargs: Any) -> bool:
@abstractmethod
def env_response(
self, messages: list[dict[str, str]], **kwargs: Any
) -> list[dict[str, Any]]:
) -> list[dict[str, Any]]:
pass

def prepare_data(self, *, inputs, processing_class):
@@ -96,12 +96,12 @@ def update_state(j, vlm_response):
state["prompt_ids"] = vlm_response.prompt_token_ids

# update the conversation with the model's response
state["messages"].append({
"role": "assistant",
"content": [
{"text": vlm_response.outputs[0].text, "type": "text"}
]
})
state["messages"].append(
{
"role": "assistant",
"content": [{"text": vlm_response.outputs[0].text, "type": "text"}],
}
)

# get token lengths of env response and new completion
total_prev_len = len(state["prompt_ids"]) + len(state["completion_ids"])
@@ -120,7 +120,7 @@ def update_state(j, vlm_response):
]

# if we are done, we mark the state as completed
# we do not want to truncate the completion ids here,
# we do not want to truncate the completion ids here,
# because the number of image tokens returned from the tools is variable
if (
self.is_completed(state["messages"])
@@ -153,7 +153,7 @@ def update_state(j, vlm_response):

for j, state in results:
states[j] = state

return states

def generate(
@@ -176,7 +176,7 @@ def generate(
}
for conversation in conversations
]

# main loop
while not all_completed:
states = self.step(states, vlm, custom_sp)
@@ -190,8 +190,7 @@ def generate(
"messages": completion_messages,
"mask": completion_mask,
}



def clean_messages_for_logging(messages):
cleaned = []
images = []
@@ -201,7 +200,10 @@ def clean_messages_for_logging(messages):
cleaned_content = []
for item in cleaned_message["content"]:
cleaned_item = item.copy()
if "image" in cleaned_item and cleaned_item["image"] is not None:
if (
"image" in cleaned_item
and cleaned_item["image"] is not None
):
images.append(cleaned_item["image"])
cleaned_item["image"] = "<PIL.Image object>"
cleaned_content.append(cleaned_item)
@@ -212,51 +214,60 @@ def preprocess_messages(prompts_messages: list[list[dict[str, Any]]], completion
cleaned_messages, images = clean_messages_for_logging(states[0]["messages"])

self.logger.info(
"Full conversation 0:\n"
+ json.dumps(cleaned_messages, indent=4)
"Full conversation 0:\n" + json.dumps(cleaned_messages, indent=4)
)
for image in images:
imgcat.imgcat(image)

return output

@staticmethod
def preprocess_messages(prompts_messages: list[list[dict[str, Any]]], completions_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
'''
def preprocess_messages(
prompts_messages: list[list[dict[str, Any]]],
completions_messages: list[list[dict[str, Any]]],
) -> list[list[dict[str, Any]]]:
"""
1. Combines prompts and completion messages into full conversations
2. Removes all messages before the first assistant message, leaving only the completion
3. Merges elements of the completion that come from the same source and are text only

Args:
prompts: list of prompt conversations
completions_messages: list of completion conversations

Returns:
list of preprocessed completion conversations
'''
"""
# Combine prompts and completions into full conversations
combined_messages = []
for prompt_msgs, completion_msgs in zip(prompts_messages, completions_messages):
conversation = []
conversation.extend(prompt_msgs)
conversation.extend(completion_msgs)
combined_messages.append(conversation)

filtered_messages = []
for completion in combined_messages:
# find the index of the first assistant message
assistant_message_index = next((i for i, message in enumerate(completion) if message["role"] == "assistant"), None)

assistant_message_index = next(
(
i
for i, message in enumerate(completion)
if message["role"] == "assistant"
),
None,
)

if assistant_message_index is not None:
# keep only messages from the first assistant message onwards
filtered_messages.append(completion[assistant_message_index:])

merged_completions = []

for completion in filtered_messages:
merged_completion = []
current_message = None

for message in completion:
# If message has non-text content, add it as is
if any(item["type"] != "text" for item in message["content"]):
@@ -265,7 +276,7 @@ def preprocess_messages(prompts_messages: list[list[dict[str, Any]]], completion
current_message = None
merged_completion.append(message)
continue

# For text messages
if current_message and current_message["role"] == message["role"]:
# Merge text content
Expand All @@ -277,11 +288,21 @@ def preprocess_messages(prompts_messages: list[list[dict[str, Any]]], completion
merged_completion.append(current_message)
current_message = {
"role": message["role"],
"content": [{"type": "text", "text": message["content"][0]["text"]}]
"content": [
{"type": "text", "text": message["content"][0]["text"]}
],
}

if current_message:
merged_completion.append(current_message)
merged_completions.append(merged_completion)

return merged_completions

return merged_completions

def log_metrics(self, data):
"""
Callback for logging metrics. Can be implemented by subclasses.

Should return a dictionary of metrics (key = metric name, value = metric value)
"""
return {}
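To make the `preprocess_messages` contract concrete, here is a small hypothetical example; the message contents are invented, and the class name `MultistepVisionEnv` is assumed from the module name rather than shown in this diff:

```python
# Invented prompt/completion pair for illustration only.
prompts = [
    [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [{"type": "text", "text": "What color is the car?"}]},
    ]
]
completions = [
    [
        {"role": "assistant", "content": [{"type": "text", "text": "<think> Zooming in on the car."}]},
        {"role": "assistant", "content": [{"type": "text", "text": " It is red. </think> <answer>red</answer>"}]},
    ]
]

merged = MultistepVisionEnv.preprocess_messages(
    prompts_messages=prompts, completions_messages=completions
)
# The system/user prompt is dropped and the two consecutive assistant text
# messages are merged into one:
# [[{"role": "assistant",
#    "content": [{"type": "text",
#                 "text": "<think> Zooming in on the car. It is red. </think> <answer>red</answer>"}]}]]
```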
12 changes: 10 additions & 2 deletions src/r1_vlm/environments/simple_vision_env.py
@@ -4,12 +4,12 @@

import imgcat
from qwen_vl_utils import process_vision_info
from verifiers import SimpleEnv
from vllm import LLM, SamplingParams # type: ignore

from r1_vlm.budget_forcing.budget_forcing import (
generate_completions_with_budget_forcing,
)
from verifiers import SimpleEnv


class SimpleVisionEnv(SimpleEnv):
@@ -93,7 +93,7 @@ def generate(
completions = vlm.generate(
vlm_inputs, sampling_params=custom_sp, use_tqdm=False
) # type: ignore

stop_reasons = [c.outputs[0].stop_reason for c in completions]
print(f"Stop reasons: {stop_reasons}")

@@ -166,6 +166,14 @@ def prepare_data(self, *, inputs, processing_class):

return conversations, texts, batch, vllm_inputs

def log_metrics(self, data):
"""
Callback for logging metrics. Can be implemented by subclasses.

Should return a dictionary of metrics (key = metric name, value = metric value)
"""
return {}


def prepare_inputs_for_env(*, inputs, processing_class):
"""
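Both environments now expose a no-op `log_metrics` hook. A minimal, hypothetical override is sketched below; the assumed shape of `data` and the metric name are illustrative rather than defined by this diff:

```python
from r1_vlm.environments.simple_vision_env import SimpleVisionEnv


class MyVisionEnv(SimpleVisionEnv):
    def log_metrics(self, data):
        # Assumes `data` carries the generated conversations; the real payload
        # is decided by whatever trainer code invokes this callback.
        conversations = data.get("conversations", [])
        if not conversations:
            return {}
        avg_turns = sum(len(c) for c in conversations) / len(conversations)
        return {"avg_messages_per_conversation": avg_turns}
```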